# Chest X-Ray Lung Segmentation - Model Training

**Author**: Deep Learning Project  
**Model**: U-Net for Lung Segmentation  
**Framework**: PyTorch (GPU-accelerated when CUDA is available, CPU otherwise)  

This notebook trains a deep learning model for accurate lung segmentation from chest X-rays.

## 1. Import Libraries and Setup

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

# PyTorch imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import transforms

# Image processing
from PIL import Image
import cv2
from tqdm.auto import tqdm

# Metrics
from sklearn.metrics import jaccard_score, f1_score, precision_score, recall_score

# Seed both random sources used in this notebook (NumPy and torch) so that
# repeated runs draw the same random numbers.
np.random.seed(42)
torch.manual_seed(42)

# Select the compute device: the CUDA GPU when one is available, else the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    # Report which GPU (device 0) was picked and how much memory it has.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    total_gib = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"Memory: {total_gib:.2f} GB")

# Global matplotlib theme applied to every figure below.
plt.style.use('seaborn-v0_8-darkgrid')
print("\n✓ Libraries imported successfully!")

## 2. Configuration

In [None]:
# Paths
BASE_DIR = Path(r"d:\DEEP LEARNING\Dataset\ChestXray")
IMAGE_DIR = BASE_DIR / "CXR_Combined" / "images"
MASK_DIR = BASE_DIR / "CXR_Combined" / "masks"
SPLIT_DIR = Path(r"d:\DEEP LEARNING\ChestXraySegmentation")

# Output directories
OUTPUT_DIR = Path(r"d:\DEEP LEARNING\ChestXraySegmentation")
MODEL_DIR = OUTPUT_DIR / "models"
RESULTS_DIR = OUTPUT_DIR / "results"
PLOTS_DIR = OUTPUT_DIR / "plots"

# Create directories. parents=True also creates OUTPUT_DIR itself when it does
# not exist yet (e.g. on a fresh machine); without it mkdir raises.
MODEL_DIR.mkdir(parents=True, exist_ok=True)
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
PLOTS_DIR.mkdir(parents=True, exist_ok=True)

# Training configuration
CONFIG = {
    # Model
    'img_size': 256,
    'in_channels': 1,
    'out_channels': 1,
    
    # Training
    'batch_size': 16,
    'num_epochs': 50,
    'learning_rate': 1e-4,
    'weight_decay': 1e-5,
    
    # Data
    # On Windows (os.name == 'nt') DataLoader workers are started with
    # "spawn": they cannot re-import a Dataset class defined interactively in a
    # notebook, and an unguarded script would re-execute itself. Load data in
    # the main process there; use 4 worker processes elsewhere.
    'num_workers': 0 if os.name == 'nt' else 4,
    # Pinned host memory only helps host->GPU copies.
    'pin_memory': torch.cuda.is_available(),
    
    # Other
    'save_every': 5,  # Save model every N epochs
    'early_stopping_patience': 10,
}

print("Configuration:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")

# Save config
with open(RESULTS_DIR / 'config.json', 'w', encoding='utf-8') as f:
    json.dump(CONFIG, f, indent=2)
print("\n✓ Configuration saved!")

## 3. Dataset Class

In [None]:
class LungSegmentationDataset(Dataset):
    """Paired chest X-ray image / lung mask dataset for segmentation.

    Each item is read from disk as grayscale and resized to
    ``img_size x img_size``. It is returned as two float32 tensors of shape
    ``(1, img_size, img_size)``: the image scaled to [0, 1] and the mask
    binarised to {0, 1} (any non-zero pixel counts as foreground).
    """
    
    def __init__(self, image_dir, mask_dir, filenames, img_size=256, augment=False):
        """
        Args:
            image_dir: Path to images directory
            mask_dir: Path to masks directory; each mask must have the same
                filename as its image
            filenames: List of image filenames belonging to this split
            img_size: Target image size (will resize to img_size x img_size)
            augment: Whether to apply data augmentation (random horizontal flip)
        """
        self.image_dir = Path(image_dir)
        self.mask_dir = Path(mask_dir)
        self.filenames = filenames
        self.img_size = img_size
        self.augment = augment
        
    def __len__(self):
        return len(self.filenames)
    
    def __getitem__(self, idx):
        """Load, resize, normalise and (optionally) augment one image/mask pair.

        Raises:
            FileNotFoundError: If the image or its mask cannot be read.
        """
        filename = self.filenames[idx]
        
        # Image and mask share the same filename in their respective folders
        image_path = self.image_dir / filename
        mask_path = self.mask_dir / filename
        
        # Read image (grayscale)
        image = self._read_grayscale(image_path)
        mask = self._read_grayscale(mask_path)
        
        # Resize: area interpolation for the image; nearest-neighbour for the
        # mask so no intermediate label values are invented
        image = cv2.resize(image, (self.img_size, self.img_size), interpolation=cv2.INTER_AREA)
        mask = cv2.resize(mask, (self.img_size, self.img_size), interpolation=cv2.INTER_NEAREST)
        
        # Normalize image to [0, 1]
        image = image.astype(np.float32) / 255.0
        mask = (mask > 0).astype(np.float32)  # Binary mask
        
        # Apply augmentation
        if self.augment:
            image, mask = self.augment_data(image, mask)
        
        # Add channel dimension: (H, W) -> (1, H, W)
        image = np.expand_dims(image, axis=0)
        mask = np.expand_dims(mask, axis=0)
        
        # Convert to tensors
        image = torch.from_numpy(image)
        mask = torch.from_numpy(mask)
        
        return image, mask
    
    @staticmethod
    def _read_grayscale(path):
        """Read *path* as a single-channel uint8 array, failing loudly if unreadable."""
        array = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
        # cv2.imread reports a missing or undecodable file by returning None
        # instead of raising; without this check the failure would surface
        # later as an opaque cv2.resize assertion.
        if array is None:
            raise FileNotFoundError(f"Could not read image file: {path}")
        return array
    
    def augment_data(self, image, mask):
        """Horizontally flip image and mask together with probability 0.5."""
        # Draw from torch's RNG: DataLoader seeds it differently in every worker,
        # whereas NumPy's global RNG is copied unchanged into forked workers and
        # would make all workers produce the same flip sequence.
        if torch.rand(1).item() < 0.5:
            # .copy() gives contiguous arrays that torch.from_numpy accepts
            image = np.fliplr(image).copy()
            mask = np.fliplr(mask).copy()
        return image, mask

print("✓ Dataset class defined!")

## 4. U-Net Model Architecture

In [None]:
class DoubleConv(nn.Module):
    """Two stacked ``3x3 conv -> BatchNorm -> ReLU`` stages.

    padding=1 keeps the spatial size unchanged. The first conv maps
    *in_channels* to *out_channels*; the second keeps *out_channels*.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        # The attribute name and layer order determine the state_dict keys
        # (double_conv.0 ... double_conv.5) stored in checkpoints.
        self.double_conv = nn.Sequential(*stages)
    
    def forward(self, x):
        features = self.double_conv(x)
        return features


class Down(nn.Module):
    """Encoder stage: halve H and W with 2x2 max-pooling, then apply DoubleConv."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(kernel_size=2)
        convs = DoubleConv(in_channels, out_channels)
        # Keep this name and order: checkpoint state_dict keys depend on them.
        self.maxpool_conv = nn.Sequential(pool, convs)
    
    def forward(self, x):
        downsampled = self.maxpool_conv(x)
        return downsampled


class Up(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, skip concatenation, DoubleConv.

    The transposed conv maps *in_channels* to ``in_channels // 2`` channels.
    The skip tensor must supply the remaining channels so that the
    concatenation has exactly *in_channels* channels for DoubleConv.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        half = in_channels // 2
        self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)
    
    def forward(self, x1, x2):
        upsampled = self.up(x1)
        
        # Align the upsampled map's spatial size with the skip tensor x2,
        # splitting the size difference as evenly as possible between sides.
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        pad_left, pad_top = gap_w // 2, gap_h // 2
        upsampled = F.pad(upsampled, [pad_left, gap_w - pad_left,
                                      pad_top, gap_h - pad_top])
        
        # Skip features first, upsampled features second, along channels.
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)


class UNet(nn.Module):
    """U-Net Architecture for Image Segmentation.

    Four encoder stages, a bottleneck with ``features[3] * 2`` channels, and a
    mirrored decoder joined to the encoder by skip connections. ``forward``
    returns raw logits with *out_channels* channels; apply ``sigmoid`` to get
    probabilities.

    Args:
        in_channels: Number of input image channels.
        out_channels: Number of output channels (one logit map per channel).
        features: Encoder widths, shallowest first; the first four entries
            are used. Defaults to a tuple so the default is immutable and
            cannot be shared and mutated across instances.
    """
    def __init__(self, in_channels=1, out_channels=1, features=(64, 128, 256, 512)):
        super().__init__()
        
        # Encoder (Downsampling)
        self.inc = DoubleConv(in_channels, features[0])
        self.down1 = Down(features[0], features[1])
        self.down2 = Down(features[1], features[2])
        self.down3 = Down(features[2], features[3])
        
        # Bottleneck
        self.down4 = Down(features[3], features[3] * 2)
        
        # Decoder (Upsampling)
        self.up1 = Up(features[3] * 2, features[3])
        self.up2 = Up(features[3], features[2])
        self.up3 = Up(features[2], features[1])
        self.up4 = Up(features[1], features[0])
        
        # Output layer: 1x1 conv projecting to per-pixel logits
        self.outc = nn.Conv2d(features[0], out_channels, kernel_size=1)
    
    def forward(self, x):
        # Encoder: keep every stage's output for the skip connections
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        
        # Decoder with skip connections
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        
        # Output
        logits = self.outc(x)
        return logits

# Build the network once as a sanity check and report its size.
model = UNet(in_channels=CONFIG['in_channels'], out_channels=CONFIG['out_channels']).to(device)

# Parameter counts (reused by the summary report at the end of the notebook).
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() if p.requires_grad else 0 for p in model.parameters())

print(f"\n✓ U-Net model created!")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")

## 5. Load Data

In [None]:
# Read the pre-computed train / validation / test split CSVs.
train_df, val_df, test_df = (
    pd.read_csv(SPLIT_DIR / f'{split}_split.csv') for split in ('train', 'val', 'test')
)

# Each split CSV lists its image filenames in a 'filename' column.
train_files, val_files, test_files = (
    df['filename'].tolist() for df in (train_df, val_df, test_df)
)

print(f"Data splits loaded:")
print(f"  Training:   {len(train_files)} images")
print(f"  Validation: {len(val_files)} images")
print(f"  Testing:    {len(test_files)} images")

# Datasets: only the training split is augmented.
train_dataset = LungSegmentationDataset(IMAGE_DIR, MASK_DIR, train_files,
                                        img_size=CONFIG['img_size'], augment=True)
val_dataset = LungSegmentationDataset(IMAGE_DIR, MASK_DIR, val_files,
                                      img_size=CONFIG['img_size'], augment=False)
test_dataset = LungSegmentationDataset(IMAGE_DIR, MASK_DIR, test_files,
                                       img_size=CONFIG['img_size'], augment=False)


def _make_loader(dataset, shuffle):
    """Wrap *dataset* in a DataLoader using the shared batch/worker settings."""
    return DataLoader(
        dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=shuffle,
        num_workers=CONFIG['num_workers'],
        pin_memory=CONFIG['pin_memory'],
    )


# Only the training loader shuffles; evaluation order stays fixed.
train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)
test_loader = _make_loader(test_dataset, shuffle=False)

print(f"\n✓ DataLoaders created!")
print(f"  Training batches:   {len(train_loader)}")
print(f"  Validation batches: {len(val_loader)}")
print(f"  Testing batches:    {len(test_loader)}")

## 6. Loss Functions and Metrics

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss for binary segmentation.

    Takes raw logits, applies sigmoid, and computes ``1 - Dice`` over all
    elements of the batch together (one scalar per call). *smooth* is added
    to numerator and denominator to keep the ratio defined for empty masks.
    """
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth
    
    def forward(self, predictions, targets):
        """Return the scalar Dice loss between *predictions* (logits) and *targets*."""
        predictions = torch.sigmoid(predictions)
        
        # Flatten. reshape (unlike view) also accepts non-contiguous tensors,
        # e.g. transposed or sliced inputs; it only copies when it must.
        predictions = predictions.reshape(-1)
        targets = targets.reshape(-1)
        
        intersection = (predictions * targets).sum()
        dice = (2. * intersection + self.smooth) / (predictions.sum() + targets.sum() + self.smooth)
        
        return 1 - dice


class CombinedLoss(nn.Module):
    """Weighted sum of BCE-with-logits and soft Dice loss.

    ``loss = alpha * BCE + (1 - alpha) * Dice``; both terms take raw logits.
    """
    def __init__(self, alpha=0.5):
        super().__init__()
        self.alpha = alpha
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()
    
    def forward(self, predictions, targets):
        bce_weight, dice_weight = self.alpha, 1 - self.alpha
        return (bce_weight * self.bce(predictions, targets)
                + dice_weight * self.dice(predictions, targets))


def calculate_iou(predictions, targets, threshold=0.5):
    """Return the IoU (Jaccard index) of thresholded *predictions* vs *targets*.

    The score is computed over every element of the inputs together, so a
    whole batch yields one number. A 1e-6 term in numerator and denominator
    makes two empty masks score 1.0. Returns a Python float.
    """
    eps = 1e-6
    pred_mask = (predictions > threshold).float()
    true_mask = targets.float()
    
    overlap = (pred_mask * true_mask).sum()
    union = pred_mask.sum() + true_mask.sum() - overlap
    return ((overlap + eps) / (union + eps)).item()


def calculate_dice(predictions, targets, threshold=0.5):
    """Return the Dice coefficient of thresholded *predictions* vs *targets*.

    Dice = 2|P ∩ T| / (|P| + |T|), computed over every element of the inputs
    together (a whole batch yields one number). A 1e-6 term in numerator and
    denominator makes two empty masks score 1.0. Returns a Python float.
    """
    eps = 1e-6
    pred_mask = (predictions > threshold).float()
    true_mask = targets.float()
    
    overlap = (pred_mask * true_mask).sum()
    denominator = pred_mask.sum() + true_mask.sum() + eps
    return ((2. * overlap + eps) / denominator).item()

print("✓ Loss functions and metrics defined!")

## 7. Training Functions

In [None]:
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over *dataloader*.

    Returns:
        ``(mean loss, mean IoU, mean Dice)``. Each is the average of the
        per-batch values, so every batch counts equally regardless of size.
    """
    model.train()
    totals = {'loss': 0.0, 'iou': 0.0, 'dice': 0.0}
    
    progress = tqdm(dataloader, desc='Training')
    for images, masks in progress:
        images, masks = images.to(device), masks.to(device)
        
        # Standard step: clear grads, forward, loss, backprop, update
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, masks)
        loss.backward()
        optimizer.step()
        
        # Metrics on sigmoid probabilities; no graph needed
        with torch.no_grad():
            probs = torch.sigmoid(logits)
            batch_iou = calculate_iou(probs, masks)
            batch_dice = calculate_dice(probs, masks)
        
        batch_loss = loss.item()
        totals['loss'] += batch_loss
        totals['iou'] += batch_iou
        totals['dice'] += batch_dice
        
        progress.set_postfix({
            'loss': f'{batch_loss:.4f}',
            'iou': f'{batch_iou:.4f}',
            'dice': f'{batch_dice:.4f}'
        })
    
    n_batches = len(dataloader)
    return totals['loss'] / n_batches, totals['iou'] / n_batches, totals['dice'] / n_batches


def validate(model, dataloader, criterion, device):
    """Evaluate *model* on *dataloader* without updating its weights.

    Runs in eval mode under ``torch.no_grad()``.

    Returns:
        ``(mean loss, mean IoU, mean Dice)``, each averaged over batches.
    """
    model.eval()
    loss_sum = iou_sum = dice_sum = 0.0
    
    with torch.no_grad():
        progress = tqdm(dataloader, desc='Validation')
        for images, masks in progress:
            images = images.to(device)
            masks = masks.to(device)
            
            logits = model(images)
            batch_loss = criterion(logits, masks).item()
            
            probs = torch.sigmoid(logits)
            batch_iou = calculate_iou(probs, masks)
            batch_dice = calculate_dice(probs, masks)
            
            loss_sum += batch_loss
            iou_sum += batch_iou
            dice_sum += batch_dice
            
            progress.set_postfix({
                'loss': f'{batch_loss:.4f}',
                'iou': f'{batch_iou:.4f}',
                'dice': f'{batch_dice:.4f}'
            })
    
    n_batches = len(dataloader)
    return loss_sum / n_batches, iou_sum / n_batches, dice_sum / n_batches

print("✓ Training functions defined!")

## 8. Train Model

In [None]:
# Initialize model, criterion, and optimizer
# (a fresh model: the one built in section 4 only served as a smoke test)
model = UNet(in_channels=CONFIG['in_channels'], out_channels=CONFIG['out_channels'])
model = model.to(device)

criterion = CombinedLoss(alpha=0.5)
optimizer = optim.Adam(model.parameters(), lr=CONFIG['learning_rate'], 
                       weight_decay=CONFIG['weight_decay'])
# Halve the LR after 5 epochs without improvement in the monitored value
# (validation loss, passed to scheduler.step in the loop). `verbose` is not
# passed: it is deprecated in recent PyTorch releases, and the training loop
# already prints the current learning rate every epoch.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, 
                                                  patience=5)

# Training history
history = {
    'train_loss': [], 'train_iou': [], 'train_dice': [],
    'val_loss': [], 'val_iou': [], 'val_dice': [],
    'lr': []
}

# Early-stopping / checkpoint state: best validation Dice seen so far and
# the number of consecutive epochs without improving on it.
best_val_dice = 0.0
patience_counter = 0

print(f"\nStarting training for {CONFIG['num_epochs']} epochs...")
print(f"Device: {device}\n")

# Wall-clock start of training, used for the total-time report after the loop.
start_time = datetime.now()

# Main loop: train, validate, step the LR scheduler, record history, keep the
# best-by-validation-Dice checkpoint, write periodic checkpoints, and stop
# early once validation Dice stalls.
for epoch in range(CONFIG['num_epochs']):
    print(f"\nEpoch {epoch + 1}/{CONFIG['num_epochs']}")
    print("-" * 60)
    
    # Train
    train_loss, train_iou, train_dice = train_one_epoch(
        model, train_loader, criterion, optimizer, device
    )
    
    # Validate
    val_loss, val_iou, val_dice = validate(
        model, val_loader, criterion, device
    )
    
    # Update learning rate. The scheduler monitors validation *loss*, while
    # checkpoint selection and early stopping below use validation *Dice*.
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']
    
    # Save history
    history['train_loss'].append(train_loss)
    history['train_iou'].append(train_iou)
    history['train_dice'].append(train_dice)
    history['val_loss'].append(val_loss)
    history['val_iou'].append(val_iou)
    history['val_dice'].append(val_dice)
    history['lr'].append(current_lr)
    
    # Print epoch summary
    print(f"\nEpoch Summary:")
    print(f"  Train - Loss: {train_loss:.4f}, IoU: {train_iou:.4f}, Dice: {train_dice:.4f}")
    print(f"  Val   - Loss: {val_loss:.4f}, IoU: {val_iou:.4f}, Dice: {val_dice:.4f}")
    print(f"  Learning Rate: {current_lr:.2e}")
    
    # Save best model (strict improvement in validation Dice); this also
    # resets the early-stopping counter. 'epoch' is stored 0-based.
    if val_dice > best_val_dice:
        best_val_dice = val_dice
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_dice': val_dice,
            'val_iou': val_iou,
            'history': history
        }, MODEL_DIR / 'best_model.pth')
        print(f"  ✓ Best model saved! (Dice: {val_dice:.4f})")
        patience_counter = 0
    else:
        patience_counter += 1
    
    # Save checkpoint every N epochs (file name uses the 1-based epoch number)
    if (epoch + 1) % CONFIG['save_every'] == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'history': history
        }, MODEL_DIR / f'checkpoint_epoch_{epoch+1}.pth')
        print(f"  ✓ Checkpoint saved (epoch {epoch+1})")
    
    # Early stopping: quit after `early_stopping_patience` consecutive epochs
    # without a new best validation Dice.
    if patience_counter >= CONFIG['early_stopping_patience']:
        print(f"\nEarly stopping triggered after {epoch + 1} epochs")
        break

end_time = datetime.now()
training_time = end_time - start_time

print(f"\n{'='*60}")
print(f"Training completed!")
print(f"Total time: {training_time}")
print(f"Best validation Dice: {best_val_dice:.4f}")
print(f"{'='*60}")

# Save final model and history. These are the weights after the last epoch
# that ran, which is not necessarily the best-validation model saved above.
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'history': history,
    'config': CONFIG
}, MODEL_DIR / 'final_model.pth')

# Save training history
history_df = pd.DataFrame(history)
history_df.to_csv(RESULTS_DIR / 'training_history.csv', index=False)
print("\n✓ Final model and history saved!")

## 9. Plot Training History

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

epochs = range(1, len(history['train_loss']) + 1)

# Train-vs-validation panels:
# (axis, history key suffix, legend label, y-axis label, panel title)
curve_panels = [
    (axes[0, 0], 'loss', 'Loss', 'Loss', 'Training and Validation Loss'),
    (axes[0, 1], 'iou', 'IoU', 'IoU Score', 'Training and Validation IoU'),
    (axes[1, 0], 'dice', 'Dice', 'Dice Coefficient', 'Training and Validation Dice Coefficient'),
]
for ax, key, label, ylabel, title in curve_panels:
    ax.plot(epochs, history[f'train_{key}'], 'b-', label=f'Training {label}', linewidth=2)
    ax.plot(epochs, history[f'val_{key}'], 'r-', label=f'Validation {label}', linewidth=2)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.legend(fontsize=11)
    ax.grid(alpha=0.3)

# Learning-rate panel: single curve on a log scale, no legend.
lr_ax = axes[1, 1]
lr_ax.plot(epochs, history['lr'], 'g-', linewidth=2)
lr_ax.set_xlabel('Epoch', fontsize=12)
lr_ax.set_ylabel('Learning Rate', fontsize=12)
lr_ax.set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
lr_ax.set_yscale('log')
lr_ax.grid(alpha=0.3)

plt.tight_layout()
plt.savefig(PLOTS_DIR / 'training_history.png', dpi=300, bbox_inches='tight')
plt.show()

print("✓ Training history plotted!")

## 10. Evaluate on Test Set

In [None]:
# Load best model (highest validation Dice saved during training).
# map_location=device remaps the stored tensors onto the active device, so a
# checkpoint written from a GPU also loads in a CPU-only session.
checkpoint = torch.load(MODEL_DIR / 'best_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
# 'epoch' is stored 0-based, hence the + 1 for display
print(f"Loaded best model from epoch {checkpoint['epoch'] + 1}")
print(f"Best validation Dice: {checkpoint['val_dice']:.4f}\n")

# Evaluate on test set (validate(): eval mode, no gradient tracking)
test_loss, test_iou, test_dice = validate(model, test_loader, criterion, device)

print(f"\n{'='*60}")
print(f"Test Set Results:")
print(f"  Loss: {test_loss:.4f}")
print(f"  IoU:  {test_iou:.4f}")
print(f"  Dice: {test_dice:.4f}")
print(f"{'='*60}")

# Save test results
test_results = {
    'test_loss': test_loss,
    'test_iou': test_iou,
    'test_dice': test_dice,
    'best_val_dice': checkpoint['val_dice'],
    'best_epoch': checkpoint['epoch'] + 1
}

with open(RESULTS_DIR / 'test_results.json', 'w', encoding='utf-8') as f:
    json.dump(test_results, f, indent=2)

print("\n✓ Test results saved!")

## 11. Visualize Predictions

In [None]:
def visualize_predictions(model, dataset, device, num_samples=6):
    """Plot model predictions for randomly chosen samples and save the figure.

    Each row shows the input image, the ground-truth mask, the sigmoid
    probability map, and the map thresholded at 0.5 (titled with that
    sample's IoU and Dice). The figure is written to
    ``PLOTS_DIR / 'test_predictions.png'``.

    Args:
        model: Segmentation network returning logits.
        dataset: Dataset yielding ``(image, mask)`` tensor pairs, as
            LungSegmentationDataset does (shape ``(1, H, W)`` each).
        device: Device the model lives on.
        num_samples: Number of rows to draw; capped at ``len(dataset)``.

    Raises:
        ValueError: If *dataset* is empty or *num_samples* is less than 1.
    """
    model.eval()
    
    if len(dataset) == 0 or num_samples < 1:
        raise ValueError("Need a non-empty dataset and num_samples >= 1 to visualize")
    # Sampling without replacement cannot draw more items than the dataset has.
    num_samples = min(num_samples, len(dataset))
    
    # Get random samples
    indices = np.random.choice(len(dataset), num_samples, replace=False)
    
    # squeeze=False keeps `axes` 2-D even for a single row, so axes[row, col]
    # indexing works for any num_samples.
    fig, axes = plt.subplots(num_samples, 4, figsize=(16, 4 * num_samples), squeeze=False)
    
    with torch.no_grad():
        for idx, sample_idx in enumerate(indices):
            image, mask = dataset[sample_idx]
            image_input = image.unsqueeze(0).to(device)  # add batch dimension
            
            # Predict
            output = model(image_input)
            pred_mask = torch.sigmoid(output).cpu().squeeze()
            pred_binary = (pred_mask > 0.5).float()
            
            # Convert to numpy
            image_np = image.squeeze().numpy()
            mask_np = mask.squeeze().numpy()
            pred_np = pred_mask.numpy()
            pred_binary_np = pred_binary.numpy()
            
            # Calculate metrics (per-sample; inputs re-expanded to 4-D)
            iou = calculate_iou(pred_mask.unsqueeze(0).unsqueeze(0), 
                               mask.unsqueeze(0))
            dice = calculate_dice(pred_mask.unsqueeze(0).unsqueeze(0), 
                                 mask.unsqueeze(0))
            
            # Plot
            axes[idx, 0].imshow(image_np, cmap='gray')
            axes[idx, 0].set_title('Input Image', fontsize=11, fontweight='bold')
            axes[idx, 0].axis('off')
            
            axes[idx, 1].imshow(mask_np, cmap='gray')
            axes[idx, 1].set_title('Ground Truth', fontsize=11, fontweight='bold')
            axes[idx, 1].axis('off')
            
            axes[idx, 2].imshow(pred_np, cmap='gray')
            axes[idx, 2].set_title('Prediction (Prob)', fontsize=11, fontweight='bold')
            axes[idx, 2].axis('off')
            
            axes[idx, 3].imshow(pred_binary_np, cmap='gray')
            axes[idx, 3].set_title(f'Prediction (Binary)\nIoU: {iou:.3f}, Dice: {dice:.3f}', 
                                  fontsize=11, fontweight='bold')
            axes[idx, 3].axis('off')
    
    plt.tight_layout()
    plt.savefig(PLOTS_DIR / 'test_predictions.png', dpi=300, bbox_inches='tight')
    plt.show()

# Visualize test predictions
visualize_predictions(model, test_dataset, device, num_samples=8)
print("✓ Predictions visualized!")

## 12. Calculate Detailed Metrics

In [None]:
def calculate_detailed_metrics(model, dataloader, device):
    """Compute per-image IoU, Dice, precision and recall over *dataloader*.

    Probabilities are thresholded at 0.5, and each image is scored with
    scikit-learn's binary metrics (``zero_division=0``).

    Returns:
        dict mapping 'iou', 'dice', 'precision', 'recall' to lists holding one
        value per image, in dataloader order.
    """
    # Metric name -> sklearn scorer (Dice equals the binary F1 score)
    scorers = {
        'iou': jaccard_score,
        'dice': f1_score,
        'precision': precision_score,
        'recall': recall_score,
    }
    per_image = {name: [] for name in scorers}
    
    model.eval()
    with torch.no_grad():
        for images, masks in tqdm(dataloader, desc='Calculating metrics'):
            images = images.to(device)
            masks = masks.to(device)
            probs = torch.sigmoid(model(images))
            
            # Score each image in the batch separately
            for prob, mask in zip(probs, masks):
                pred_flat = (prob > 0.5).float().cpu().numpy().flatten()
                mask_flat = mask.cpu().numpy().flatten()
                for name, scorer in scorers.items():
                    per_image[name].append(scorer(mask_flat, pred_flat, zero_division=0))
    
    return per_image

# Calculate metrics on test set
print("Calculating detailed metrics on test set...\n")
test_metrics = calculate_detailed_metrics(model, test_loader, device)

# Print summary statistics per metric
divider = '=' * 60
print(f"\n{divider}")
print("Detailed Test Set Metrics:")
print(divider)
summary_stats = (('Mean', np.mean), ('Median', np.median), ('Std', np.std),
                 ('Min', np.min), ('Max', np.max))
for metric_name, values in test_metrics.items():
    print(f"\n{metric_name.upper()}:")
    for stat_label, stat_fn in summary_stats:
        print(f"  {stat_label + ':':<8}{stat_fn(values):.4f}")

# Save metrics (one row per test image)
metrics_df = pd.DataFrame(test_metrics)
metrics_df.to_csv(RESULTS_DIR / 'detailed_test_metrics.csv', index=False)
print("\n✓ Detailed metrics saved!")

## 13. Plot Metric Distributions

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

metrics_to_plot = ['iou', 'dice', 'precision', 'recall']
titles = ['IoU (Jaccard Index)', 'Dice Coefficient', 'Precision', 'Recall']
colors = ['steelblue', 'coral', 'mediumseagreen', 'mediumpurple']

# axes.flat walks the 2x2 grid row by row: (0,0), (0,1), (1,0), (1,1).
for ax, metric, title, color in zip(axes.flat, metrics_to_plot, titles, colors):
    values = test_metrics[metric]
    mean_val = np.mean(values)
    median_val = np.median(values)
    
    # Histogram of per-image scores, with mean and median marked
    ax.hist(values, bins=30, color=color, alpha=0.7, edgecolor='black')
    ax.axvline(mean_val, color='red', linestyle='--',
               linewidth=2, label=f'Mean: {mean_val:.4f}')
    ax.axvline(median_val, color='green', linestyle=':',
               linewidth=2, label=f'Median: {median_val:.4f}')
    
    ax.set_xlabel(title, fontsize=12)
    ax.set_ylabel('Frequency', fontsize=12)
    ax.set_title(f'{title} Distribution', fontsize=13, fontweight='bold')
    ax.legend(fontsize=10)
    ax.grid(alpha=0.3)

plt.tight_layout()
plt.savefig(PLOTS_DIR / 'metric_distributions.png', dpi=300, bbox_inches='tight')
plt.show()

print("✓ Metric distributions plotted!")

## 14. Final Summary Report

In [None]:
# Plain-text summary assembled from values computed in earlier cells
# (CONFIG, parameter counts, history, checkpoint, test metrics).
summary_report = f"""
{'='*70}
CHEST X-RAY LUNG SEGMENTATION - TRAINING SUMMARY REPORT
{'='*70}

MODEL ARCHITECTURE
{'-'*70}
Model:                     U-Net
Input Channels:            {CONFIG['in_channels']}
Output Channels:           {CONFIG['out_channels']}
Image Size:                {CONFIG['img_size']}x{CONFIG['img_size']}
Total Parameters:          {total_params:,}
Trainable Parameters:      {trainable_params:,}

TRAINING CONFIGURATION
{'-'*70}
Epochs:                    {len(history['train_loss'])}
Batch Size:                {CONFIG['batch_size']}
Learning Rate:             {CONFIG['learning_rate']}
Weight Decay:              {CONFIG['weight_decay']}
Loss Function:             Combined BCE + Dice Loss
Optimizer:                 Adam
Device:                    {device}

DATASET
{'-'*70}
Training Samples:          {len(train_files)}
Validation Samples:        {len(val_files)}
Test Samples:              {len(test_files)}
Total:                     {len(train_files) + len(val_files) + len(test_files)}

TRAINING RESULTS
{'-'*70}
Training Time:             {training_time}
Best Epoch:                {checkpoint['epoch'] + 1}
Best Validation Dice:      {checkpoint['val_dice']:.4f}
Best Validation IoU:       {checkpoint['val_iou']:.4f}
Final Training Loss:       {history['train_loss'][-1]:.4f}
Final Validation Loss:     {history['val_loss'][-1]:.4f}

TEST SET PERFORMANCE
{'-'*70}
Test Loss:                 {test_loss:.4f}
Test IoU (mean ± std):     {np.mean(test_metrics['iou']):.4f} ± {np.std(test_metrics['iou']):.4f}
Test Dice (mean ± std):    {np.mean(test_metrics['dice']):.4f} ± {np.std(test_metrics['dice']):.4f}
Test Precision:            {np.mean(test_metrics['precision']):.4f} ± {np.std(test_metrics['precision']):.4f}
Test Recall:               {np.mean(test_metrics['recall']):.4f} ± {np.std(test_metrics['recall']):.4f}

SAVED FILES
{'-'*70}
✓ best_model.pth           - Best model checkpoint
✓ final_model.pth          - Final model state
✓ training_history.csv     - Complete training history
✓ test_results.json        - Test set results
✓ detailed_test_metrics.csv - Per-image metrics
✓ training_history.png     - Training curves
✓ test_predictions.png     - Sample predictions
✓ metric_distributions.png - Metric distributions

{'='*70}
Training completed successfully!
{'='*70}
"""

print(summary_report)

# Save summary. The report contains non-ASCII characters ('✓', '±'). Without
# an explicit encoding, open() uses the platform default; on Windows that is
# typically cp1252, which cannot encode '✓' and would raise UnicodeEncodeError.
with open(RESULTS_DIR / 'training_summary.txt', 'w', encoding='utf-8') as f:
    f.write(summary_report)

print("\n✓ Summary report saved to 'training_summary.txt'")

## Conclusion

The U-Net model has been successfully trained for lung segmentation on chest X-rays:

### Key Achievements:
1. **Model Architecture**: Implemented U-Net with skip connections for precise segmentation
2. **Training**: Trained with a combined BCE + Dice loss, balancing per-pixel classification against region overlap
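3. **Performance**: Evaluated on the held-out test set with IoU, Dice, precision and recall (see `results/test_results.json` and `results/detailed_test_metrics.csv`)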
4. **Reproducibility**: All configurations, models, and metrics saved

### Saved Artifacts:
- **Models**: Best and final checkpoints saved in `models/`
- **Results**: Training history and test metrics in `results/`
- **Visualizations**: Training curves and predictions in `plots/`

The best checkpoint (`models/best_model.pth`) can be loaded for inference on new chest X-ray images; independent validation is still required before any clinical use.