In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

import cv2
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import rotate
import elasticdeform
import os
import math
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Fix the global RNG seeds up front so weight init, dummy data and the
# augmentation draws are repeatable after a kernel restart.
RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')


In [None]:
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by an in-place ReLU.

    ``padding=1`` keeps the spatial size unchanged, which differs from the
    unpadded ("valid") convolutions used in the original U-Net paper.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by both convolutions.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        # Attribute name and layer order define the saved state_dict keys
        # (double_conv.0.*, double_conv.2.*), so keep them stable.
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, **conv_kwargs),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, **conv_kwargs),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply conv -> ReLU -> conv -> ReLU to ``x``."""
        return self.double_conv(x)

class Down(nn.Module):
    """Encoder stage: 2x2 max-pool (floor-halves H and W), then a DoubleConv.

    Args:
        in_channels: channels entering the stage.
        out_channels: channels leaving the stage.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # The (pool, convs) order and attribute name are part of the state_dict keys.
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Downsample ``x`` and refine it with two convolutions."""
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, skip concatenation, DoubleConv.

    Args:
        in_channels: channels of the deeper map being upsampled. The transposed
            conv halves them, and the concatenated skip tensor is expected to
            supply the remaining ``in_channels - in_channels // 2`` channels so
            that DoubleConv receives ``in_channels`` in total.
        out_channels: channels produced by the stage.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        half = in_channels // 2
        self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s H/W, concatenate ``[x2, x1]`` and convolve."""
        upsampled = self.up(x1)

        # Pad left/right and top/bottom so odd-sized skips still line up.
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]
        top, left = gap_h // 2, gap_w // 2
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        # Skip features first, upsampled features second: self.conv's weights
        # depend on this channel order.
        merged = torch.cat((x2, upsampled), dim=1)
        return self.conv(merged)


In [None]:
class UNet(nn.Module):
    """Four-level U-Net encoder/decoder with a sigmoid-activated output.

    Follows the U-Net layout (64 -> 1024 channels, four skip connections) but
    uses padded convolutions plus size-matching padding in the decoder, so the
    output keeps the input's spatial size.

    Args:
        n_channels: number of input image channels.
        n_classes: number of output channels (per-pixel probabilities in [0, 1]).
    """
    def __init__(self, n_channels=1, n_classes=1):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        # Contracting path: 64 -> 128 -> 256 -> 512 -> 1024 channels.
        # Creation order is kept so parameter initialisation is unchanged.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)

        # Expansive path mirrors the encoder back to 64 channels.
        self.up1 = Up(1024, 512)
        self.up2 = Up(512, 256)
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 64)

        # 1x1 conv produces per-pixel scores; sigmoid maps them to probabilities.
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return per-pixel probabilities of shape (N, n_classes, H, W)."""
        skips = [self.inc(x)]
        for stage in (self.down1, self.down2, self.down3):
            skips.append(stage(skips[-1]))
        feat = self.down4(skips[-1])

        # Pair each decoder stage with the encoder output of matching depth.
        for stage, skip in zip((self.up1, self.up2, self.up3, self.up4), reversed(skips)):
            feat = stage(feat, skip)

        return self.sigmoid(self.outc(feat))

# Build the network on the selected device and report its size.
model = UNet(n_channels=1, n_classes=1).to(device)
total_params = sum(param.numel() for param in model.parameters())
print(f'Total parameters: {total_params:,}')

# Smoke-test one forward pass on a random single-channel 512x512 image.
dummy_input = torch.randn(1, 1, 512, 512).to(device)
with torch.no_grad():
    output = model(dummy_input)
print(f'Input shape: {dummy_input.shape}')
print(f'Output shape: {output.shape}')


In [None]:
class WeightedBCELoss(nn.Module):
    """Binary cross-entropy on probabilities with an extra weight on positives.

    loss = -mean(pos_weight * t * log(p) + (1 - t) * log(1 - p))

    Args:
        pos_weight: multiplier applied to the positive-class term.
    """
    def __init__(self, pos_weight=1.0):
        super().__init__()
        self.pos_weight = pos_weight

    def forward(self, inputs, targets):
        """Return the scalar mean loss.

        ``inputs`` are treated as probabilities (e.g. sigmoid output) and are
        clamped to [1e-7, 1 - 1e-7] so neither log term evaluates log(0).
        """
        eps = 1e-7
        probs = inputs.clamp(eps, 1.0 - eps)
        positive_term = self.pos_weight * targets * probs.log()
        negative_term = (1 - targets) * (1 - probs).log()
        return (-(positive_term + negative_term)).mean()

class DiceLoss(nn.Module):
    """Soft Dice loss ``1 - (2*|P*T| + s) / (|P| + |T| + s)`` pooled over all elements.

    Args:
        smooth: additive smoothing; keeps the ratio defined (loss 0) when both
            prediction and target are all zeros.
    """
    def __init__(self, smooth=1):
        super().__init__()
        self.smooth = smooth

    def forward(self, inputs, targets):
        """Return the scalar Dice loss over the whole batch (not a per-image mean).

        Args:
            inputs: predicted probabilities, any shape.
            targets: ground truth with the same number of elements.
        """
        # reshape (not view) so non-contiguous tensors, e.g. from transpose/permute
        # or strided slicing, are accepted instead of raising a RuntimeError.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (inputs * targets).sum()
        dice = (2. * intersection + self.smooth) / (inputs.sum() + targets.sum() + self.smooth)

        return 1 - dice

class CombinedLoss(nn.Module):
    """Unweighted sum of WeightedBCELoss and DiceLoss.

    Args:
        pos_weight: positive-class weight forwarded to the BCE term.
        smooth: smoothing constant forwarded to the Dice term.
    """
    def __init__(self, pos_weight=2.0, smooth=1):
        super().__init__()
        self.bce_loss = WeightedBCELoss(pos_weight)
        self.dice_loss = DiceLoss(smooth)

    def forward(self, inputs, targets):
        """Return ``bce(inputs, targets) + dice(inputs, targets)``."""
        return sum(term(inputs, targets) for term in (self.bce_loss, self.dice_loss))

# Sanity-check the combined loss: random probabilities vs. random binary targets.
criterion = CombinedLoss(pos_weight=2.0)
dummy_pred = torch.randn(1, 1, 512, 512).sigmoid()
dummy_target = torch.randint(0, 2, (1, 1, 512, 512)).float()

test_loss = criterion(dummy_pred, dummy_target)
print(f'Test loss value: {test_loss.item():.6f}')


In [None]:
class CellSegmentationDataset(Dataset):
    """Paired grayscale image / binary mask dataset for cell segmentation.

    Each item is an ``(image, mask)`` pair of float32 tensors shaped
    ``(1, 512, 512)`` with values in [0, 1]. If either file cannot be read,
    random placeholder data is generated instead (demo mode).

    Args:
        image_paths: list of image file paths.
        mask_paths: list of mask file paths, aligned by index with ``image_paths``.
        transform: stored on the instance; NOTE(review): not applied anywhere here.
        augment: if True, apply random rotation/flip/elastic/intensity augmentation.
    """

    def __init__(self, image_paths, mask_paths, transform=None, augment=True):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform
        self.augment = augment

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, normalise, resize and (optionally) augment sample ``idx``."""
        image = cv2.imread(self.image_paths[idx], cv2.IMREAD_GRAYSCALE)
        mask = cv2.imread(self.mask_paths[idx], cv2.IMREAD_GRAYSCALE)

        # Missing/unreadable files: fall back to random demo data.
        if image is None or mask is None:
            image = np.random.randint(0, 255, (512, 512), dtype=np.uint8)
            mask = np.random.randint(0, 2, (512, 512), dtype=np.uint8) * 255

        # Normalize to [0, 1]
        image = image.astype(np.float32) / 255.0
        mask = mask.astype(np.float32) / 255.0

        # Bilinear for the image; nearest-neighbour for the mask so object
        # boundaries are not blended into fractional label values.
        image = cv2.resize(image, (512, 512), interpolation=cv2.INTER_LINEAR)
        mask = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_NEAREST)

        if self.augment:
            image, mask = self.augment_data(image, mask)

        return self._to_tensor(image), self._to_tensor(mask)

    @staticmethod
    def _to_tensor(array):
        """Convert an (H, W) array into a contiguous float32 tensor of shape (1, H, W).

        ``np.fliplr`` returns a negative-stride view, which ``torch.from_numpy``
        rejects with a ValueError, so the array is copied to contiguous memory first.
        """
        return torch.from_numpy(np.ascontiguousarray(array, dtype=np.float32)).unsqueeze(0)

    def elastic_transform(self, image, mask, alpha=1000, sigma=50, points=3):
        """Warp image and mask with the same random smooth displacement field.

        Random displacement vectors are drawn on a coarse ``points x points``
        grid (std ``sigma`` pixels) and interpolated over the whole image,
        following the U-Net paper's elastic-deformation scheme. The image is
        resampled with cubic splines and clipped back to [0, 1]; the mask uses
        order-0 (nearest) interpolation so it stays binary.

        Args:
            image: (H, W) float image.
            mask: (H, W) mask aligned with ``image``.
            alpha: ignored; kept only for backward compatibility
                (``elasticdeform.deform_grid`` has no such parameter).
            sigma: standard deviation of the grid displacements, in pixels.
            points: control points per axis of the displacement grid.

        Returns:
            ``(image, mask)`` deformed, or the inputs unchanged if deformation fails.
        """
        displacement = np.random.randn(2, points, points) * sigma
        try:
            image_deformed, mask_deformed = elasticdeform.deform_grid(
                [np.ascontiguousarray(image), np.ascontiguousarray(mask)],
                displacement,
                order=[3, 0],
            )
        except Exception as exc:
            # Augmentation stays best-effort, but the reason is no longer swallowed
            # silently (note: this notebook filters warnings globally).
            warnings.warn(f'elastic_transform skipped: {exc!r}')
            return image, mask
        return np.clip(image_deformed, 0.0, 1.0), mask_deformed

    def augment_data(self, image, mask, elastic_sigma=10):
        """Randomly rotate, flip, elastically deform and rescale intensity.

        Each step fires independently with probability 0.5. Geometric steps are
        applied identically to image and mask; the mask is always resampled with
        nearest-neighbour interpolation so it stays binary, and the image is kept
        in [0, 1]. Returned arrays are C-contiguous.

        Args:
            elastic_sigma: displacement std (pixels) for the elastic step; 10 px
                matches the value reported in the U-Net paper.
        """
        if np.random.random() > 0.5:
            angle = np.random.uniform(-180, 180)
            # Cubic-spline rotation can overshoot [0, 1]; order=0 keeps mask labels exact.
            image = np.clip(rotate(image, angle, reshape=False), 0.0, 1.0)
            mask = rotate(mask, angle, reshape=False, order=0)

        if np.random.random() > 0.5:
            image = np.fliplr(image)
            mask = np.fliplr(mask)

        # Elastic deformation (most important for cells)
        if np.random.random() > 0.5:
            image, mask = self.elastic_transform(image, mask, sigma=elastic_sigma)

        if np.random.random() > 0.5:
            image = image * np.random.uniform(0.8, 1.2)
            image = np.clip(image, 0, 1)

        return np.ascontiguousarray(image), np.ascontiguousarray(mask)

# Create dummy dataset for demonstration
def create_dummy_data_paths(num_samples=100):
    """Return ``(image_paths, mask_paths)`` of placeholder file names.

    Names are ``dummy_image_{i}.png`` and ``dummy_mask_{i}.png`` for
    ``i in range(num_samples)``.
    """
    indices = range(num_samples)
    image_paths = list(map('dummy_image_{}.png'.format, indices))
    mask_paths = list(map('dummy_mask_{}.png'.format, indices))
    return image_paths, mask_paths

# Create datasets from ONE pool of paths split in two: calling
# create_dummy_data_paths separately for train and val would start both at
# index 0, so the same file names (and files, if they exist) would leak
# from training into validation.
NUM_TRAIN, NUM_VAL = 80, 20
all_images, all_masks = create_dummy_data_paths(NUM_TRAIN + NUM_VAL)
train_images, train_masks = all_images[:NUM_TRAIN], all_masks[:NUM_TRAIN]
val_images, val_masks = all_images[NUM_TRAIN:], all_masks[NUM_TRAIN:]
assert not set(train_images) & set(val_images), 'train/val file overlap'

train_dataset = CellSegmentationDataset(train_images, train_masks, augment=True)
val_dataset = CellSegmentationDataset(val_images, val_masks, augment=False)

print(f'Training samples: {len(train_dataset)}')
print(f'Validation samples: {len(val_dataset)}')

# Smoke-test one (augmented) training sample.
sample_image, sample_mask = train_dataset[0]
print(f'Sample image shape: {sample_image.shape}')
print(f'Sample mask shape: {sample_mask.shape}')


In [None]:
def dice_coefficient(pred, target, smooth=1):
    """Dice coefficient ``(2*|P*T| + smooth) / (|P| + |T| + smooth)``.

    Works on soft (probability) or hard (0/1) tensors of any shape. All elements
    are pooled, so a batch yields one score rather than a per-image mean.

    Args:
        pred: predictions, same number of elements as ``target``.
        target: ground-truth values.
        smooth: additive smoothing; gives 1.0 when both inputs are all zeros.

    Returns:
        0-dim tensor holding the Dice score.
    """
    # reshape (not view) so non-contiguous inputs are accepted.
    pred = pred.reshape(-1)
    target = target.reshape(-1)
    intersection = (pred * target).sum()
    return (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)

def iou_score(pred, target, smooth=1):
    """Intersection-over-union ``(|P*T| + smooth) / (|P| + |T| - |P*T| + smooth)``.

    Works on soft or hard tensors of any shape; all elements are pooled into a
    single score.

    Args:
        pred: predictions, same number of elements as ``target``.
        target: ground-truth values.
        smooth: additive smoothing; gives 1.0 when both inputs are all zeros.

    Returns:
        0-dim tensor holding the IoU score.
    """
    # reshape (not view) so non-contiguous inputs are accepted.
    pred = pred.reshape(-1)
    target = target.reshape(-1)
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum() - intersection
    return (intersection + smooth) / (union + smooth)

def evaluate_model(model, test_loader, device, threshold=0.5):
    """Evaluate a segmentation model with binarised predictions.

    Each batch's outputs are thresholded to 0/1, then Dice and IoU are computed
    with all pixels of the batch pooled. The returned values are the mean of
    these per-batch scores, so a smaller final batch counts as much as a full one.

    Args:
        model: network returning probabilities for a batch of images.
        test_loader: iterable of ``(images, masks)`` batches.
        device: device the batches are moved to.
        threshold: probability above which a pixel is predicted as foreground.

    Returns:
        ``(mean_dice, mean_iou)`` as floats.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()
    dice_scores = []
    iou_scores = []

    with torch.no_grad():
        for images, masks in test_loader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)

            predictions = (outputs > threshold).float()

            dice_scores.append(dice_coefficient(predictions, masks).item())
            iou_scores.append(iou_score(predictions, masks).item())

    # np.mean([]) would silently return NaN (warnings are filtered in this notebook).
    if not dice_scores:
        raise ValueError('test_loader yielded no batches')
    return np.mean(dice_scores), np.mean(iou_scores)

# Sanity-check both metrics: random soft predictions vs. random binary targets.
dummy_pred = torch.randn(2, 1, 512, 512).sigmoid()
dummy_target = torch.randint(0, 2, (2, 1, 512, 512)).float()

test_dice = dice_coefficient(dummy_pred, dummy_target)
test_iou = iou_score(dummy_pred, dummy_target)

for label, score in (('Test Dice coefficient', test_dice), ('Test IoU score', test_iou)):
    print(f'{label}: {score.item():.4f}')


In [None]:
def _train_one_epoch(model, loader, criterion, optimizer, device, desc):
    """Run one optimisation pass over ``loader``; return (mean loss, mean soft Dice) per batch."""
    model.train()
    total_loss = 0.0
    total_dice = 0.0

    pbar = tqdm(loader, desc=desc)
    for images, masks in pbar:
        images, masks = images.to(device), masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # Soft Dice on the raw probabilities (no threshold) for monitoring only.
        with torch.no_grad():
            dice = dice_coefficient(outputs, masks)
        total_dice += dice.item()

        pbar.set_postfix({
            'Loss': f'{loss.item():.6f}',
            'Dice': f'{dice.item():.4f}'
        })

    n_batches = len(loader)
    return total_loss / n_batches, total_dice / n_batches


def _validate_one_epoch(model, loader, criterion, device, desc):
    """Evaluate ``loader`` without gradients; return (mean loss, mean soft Dice) per batch."""
    model.eval()
    total_loss = 0.0
    total_dice = 0.0

    pbar = tqdm(loader, desc=desc)
    with torch.no_grad():
        for images, masks in pbar:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            loss = criterion(outputs, masks)
            total_loss += loss.item()

            dice = dice_coefficient(outputs, masks)
            total_dice += dice.item()

            pbar.set_postfix({
                'Loss': f'{loss.item():.6f}',
                'Dice': f'{dice.item():.4f}'
            })

    n_batches = len(loader)
    return total_loss / n_batches, total_dice / n_batches


def train_unet(model, train_loader, val_loader, device, num_epochs=50,
               lr=1e-4, momentum=0.99, weight_decay=5e-4,
               early_stopping_patience=20, save_path='best_unet_cell_seg.pth'):
    """Train ``model`` with SGD on CombinedLoss, checkpointing the best validation loss.

    Each epoch runs a training pass and a validation pass. The learning rate is
    halved after 10 epochs without validation-loss improvement (floor 1e-7).
    Training stops early after ``early_stopping_patience`` epochs without
    improvement. The best weights are written to ``save_path``; ``model`` itself
    is left holding the weights from the LAST epoch run.

    Args:
        model: network returning probabilities.
        train_loader: iterable of ``(images, masks)`` training batches (must support ``len``).
        val_loader: iterable of ``(images, masks)`` validation batches (must support ``len``).
        device: device the batches are moved to.
        num_epochs: maximum number of epochs.
        lr, momentum, weight_decay: SGD hyperparameters.
        early_stopping_patience: epochs without val-loss improvement before stopping.
        save_path: file the best ``state_dict`` is saved to.

    Returns:
        dict with per-epoch lists 'train_losses', 'val_losses',
        'train_dice_scores' and 'val_dice_scores'.

    Raises:
        ValueError: if either loader has no batches.
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError('train_loader and val_loader must each yield at least one batch')

    criterion = CombinedLoss(pos_weight=2.0)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10, min_lr=1e-7)

    history = {
        'train_losses': [],
        'val_losses': [],
        'train_dice_scores': [],
        'val_dice_scores': []
    }
    best_val_loss = float('inf')
    epochs_without_improvement = 0

    for epoch in range(num_epochs):
        tag = f'Epoch {epoch+1}/{num_epochs}'
        avg_train_loss, avg_train_dice = _train_one_epoch(
            model, train_loader, criterion, optimizer, device, f'{tag} [Train]')
        avg_val_loss, avg_val_dice = _validate_one_epoch(
            model, val_loader, criterion, device, f'{tag} [Val]')

        history['train_losses'].append(avg_train_loss)
        history['val_losses'].append(avg_val_loss)
        history['train_dice_scores'].append(avg_train_dice)
        history['val_dice_scores'].append(avg_val_dice)

        print(f'\nEpoch {epoch+1}/{num_epochs}:')
        print(f'  Train Loss: {avg_train_loss:.6f}, Train Dice: {avg_train_dice:.4f}')
        print(f'  Val Loss: {avg_val_loss:.6f}, Val Dice: {avg_val_dice:.4f}')
        # Printed before scheduler.step(): the rate that was used during this epoch.
        print(f'  LR: {optimizer.param_groups[0]["lr"]:.2e}')

        scheduler.step(avg_val_loss)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), save_path)
            epochs_without_improvement = 0
            print(f'  -> New best model saved! (Val Loss: {best_val_loss:.6f})')
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= early_stopping_patience:
            print(f'Early stopping at epoch {epoch+1}')
            break

    return history


In [None]:
import multiprocessing

# Create data loaders
batch_size = 2  # Small batch size due to memory constraints
# Worker processes started with 'spawn'/'forkserver' (default on Windows/macOS)
# must re-import the dataset class, which fails for classes defined inside a
# notebook's __main__. Only use background workers when processes are forked.
num_workers = 2 if multiprocessing.get_start_method() == 'fork' else 0
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

print(f'Training batches: {len(train_loader)}')
print(f'Validation batches: {len(val_loader)}')

# Initialize fresh model
model = UNet(n_channels=1, n_classes=1).to(device)

# Train the model
print("Starting training...")
history = train_unet(model, train_loader, val_loader, device, num_epochs=20)

print("\nTraining completed!")


In [None]:
def plot_training_history(history):
    """Draw side-by-side loss and Dice curves (train blue, val red) and show the figure.

    Args:
        history: dict with 'train_losses', 'val_losses', 'train_dice_scores'
            and 'val_dice_scores' sequences, one value per epoch.
    """
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    panels = [
        (axes[0], 'train_losses', 'val_losses', 'Train Loss', 'Val Loss',
         'Training and Validation Loss', 'Loss'),
        (axes[1], 'train_dice_scores', 'val_dice_scores', 'Train Dice', 'Val Dice',
         'Training and Validation Dice Score', 'Dice Score'),
    ]
    for ax, train_key, val_key, train_label, val_label, title, ylabel in panels:
        ax.plot(history[train_key], label=train_label, color='blue')
        ax.plot(history[val_key], label=val_label, color='red')
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.show()

# Plot training history
plot_training_history(history)

# Report the metrics recorded for the last completed epoch.
final_metrics = (
    ('Final Training Loss', 'train_losses', '.6f'),
    ('Final Validation Loss', 'val_losses', '.6f'),
    ('Final Training Dice', 'train_dice_scores', '.4f'),
    ('Final Validation Dice', 'val_dice_scores', '.4f'),
)
for label, key, fmt in final_metrics:
    print(f"{label}: {history[key][-1]:{fmt}}")
