In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
import cv2
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [None]:
class RealTimeConfig:
    """Hyper-parameters and runtime settings for the kidney segmentation pipeline.

    Every setting is a class attribute, so it can be read either from the class
    (``RealTimeConfig.IMAGE_SIZE``) or from an instance (``config.DEVICE``).
    """
    # --- runtime ---------------------------------------------------------
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    RANDOM_SEED = 42  # seed for the train/validation split

    # --- data ------------------------------------------------------------
    IMAGE_SIZE = 256  # side length images and masks are resized to
    BATCH_SIZE = 16

    # --- optimisation ----------------------------------------------------
    LEARNING_RATE = 1e-3
    NUM_EPOCHS = 150
    PATIENCE = 20  # epochs without validation-loss improvement before early stopping

    # --- inference -------------------------------------------------------
    INFERENCE_THRESHOLD = 0.5  # presumably the probability cut-off for binarising predictions

In [None]:
class TrainingLogger:
    """Accumulate per-epoch train/validation loss and segmentation metrics and export them.

    ``log_epoch`` appends one value to each series. ``plot_metrics`` writes a 2x3 grid of
    curves to ``<save_dir>/training_metrics.png``. ``save_metrics`` writes every series as
    columns of ``<save_dir>/training_metrics.csv``.

    Args:
        save_dir: output directory, created if missing.
    """

    # Keys each metrics dict passed to ``log_epoch`` must contain.
    METRIC_KEYS = ('dice', 'iou', 'precision', 'recall', 'f1')

    def __init__(self, save_dir='/kaggle/working'):
        self.save_dir = save_dir
        os.makedirs(save_dir, exist_ok=True)

        # One list per split and quantity; index i holds the value for epoch i + 1.
        self.train_losses = []
        self.val_losses = []
        self.train_dice = []
        self.val_dice = []
        self.train_iou = []
        self.val_iou = []
        self.train_precision = []
        self.val_precision = []
        self.train_recall = []
        self.val_recall = []
        self.train_f1 = []
        self.val_f1 = []

    def _series(self):
        """Return ``(plot title, CSV column stem, train values, val values)`` for every series.

        The order defines both the plot panel order and the CSV column order.
        """
        return [
            ('Loss', 'Loss', self.train_losses, self.val_losses),
            ('Dice Coefficient', 'Dice', self.train_dice, self.val_dice),
            ('IoU Score', 'IoU', self.train_iou, self.val_iou),
            ('Precision', 'Precision', self.train_precision, self.val_precision),
            ('Recall', 'Recall', self.train_recall, self.val_recall),
            ('F1 Score', 'F1', self.train_f1, self.val_f1),
        ]

    def log_epoch(self,
                  train_loss, val_loss,
                  train_metrics, val_metrics):
        """Record one epoch of results.

        Args:
            train_loss: mean training loss for the epoch.
            val_loss: mean validation loss for the epoch.
            train_metrics: mapping containing every key in ``METRIC_KEYS``.
            val_metrics: mapping containing every key in ``METRIC_KEYS``.

        Raises:
            KeyError: if either mapping lacks a metric. Nothing is recorded in that case.
        """
        # Resolve every value before mutating anything. A missing key then cannot leave
        # the series at different lengths, which would make save_metrics fail.
        train_values = [train_metrics[key] for key in self.METRIC_KEYS]
        val_values = [val_metrics[key] for key in self.METRIC_KEYS]

        self.train_losses.append(train_loss)
        self.val_losses.append(val_loss)
        for key, train_value, val_value in zip(self.METRIC_KEYS, train_values, val_values):
            getattr(self, f'train_{key}').append(train_value)
            getattr(self, f'val_{key}').append(val_value)

    def plot_metrics(self):
        """Save a 2x3 grid of train/validation curves to ``training_metrics.png``.

        Epochs are numbered from 1, matching the training loop's console output.
        """
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))
        try:
            for ax, (title, _, train_data, val_data) in zip(axes.flat, self._series()):
                ax.plot(range(1, len(train_data) + 1), train_data, label=f'Train {title}')
                ax.plot(range(1, len(val_data) + 1), val_data, label=f'Validation {title}')
                ax.set_title(f'{title} Over Epochs')
                ax.set_xlabel('Epoch')
                ax.set_ylabel(title)
                ax.legend()

            fig.tight_layout()
            fig.savefig(os.path.join(self.save_dir, 'training_metrics.png'))
        finally:
            # Close this specific figure even if saving fails, so repeated calls don't leak figures.
            plt.close(fig)

    def save_metrics(self):
        """Write all recorded series to ``training_metrics.csv`` (one row per epoch)."""
        import pandas as pd

        columns = {}
        for _, stem, train_data, val_data in self._series():
            columns[f'Train {stem}'] = train_data
            columns[f'Val {stem}'] = val_data

        pd.DataFrame(columns).to_csv(os.path.join(self.save_dir, 'training_metrics.csv'), index=False)

In [None]:
class SegmentationMetrics:
    """Overlap metrics for binary segmentation, aggregated over every element of the inputs.

    Each method flattens ``pred`` and ``target`` and sums over the whole tensor, so a batch
    is scored as one large mask (micro-averaging). ``smooth`` is added to numerator and
    denominator to avoid 0/0. With it, an empty prediction against an empty target scores
    1.0 on Dice and IoU. All methods return 0-d tensors.
    """

    @staticmethod
    def _flatten(pred, target):
        """Flatten both tensors to 1-D.

        ``reshape`` is used instead of ``view`` so non-contiguous inputs (e.g. transposed or
        sliced tensors) are accepted.
        """
        return pred.reshape(-1), target.reshape(-1)

    @staticmethod
    def dice_coefficient(pred, target, smooth=1e-7):
        """Dice = 2|P∩T| / (|P| + |T|); works on probabilities (soft Dice) or binary masks."""
        pred, target = SegmentationMetrics._flatten(pred, target)
        intersection = (pred * target).sum()
        return (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)

    @staticmethod
    def iou_score(pred, target, smooth=1e-7):
        """Intersection over Union = |P∩T| / (|P| + |T| - |P∩T|)."""
        pred, target = SegmentationMetrics._flatten(pred, target)
        intersection = (pred * target).sum()
        union = (pred + target).sum() - intersection
        return (intersection + smooth) / (union + smooth)

    @staticmethod
    def precision(pred, target, smooth=1e-7):
        """Precision = TP / (TP + FP) = |P∩T| / |P|."""
        pred, target = SegmentationMetrics._flatten(pred, target)
        true_positives = (pred * target).sum()
        return (true_positives + smooth) / (pred.sum() + smooth)

    @staticmethod
    def recall(pred, target, smooth=1e-7):
        """Recall = TP / (TP + FN) = |P∩T| / |T|."""
        pred, target = SegmentationMetrics._flatten(pred, target)
        true_positives = (pred * target).sum()
        return (true_positives + smooth) / (target.sum() + smooth)

    @staticmethod
    def f1_score(precision, recall, smooth=1e-7):
        """Harmonic mean of ``precision`` and ``recall``; ``smooth`` guards against 0/0."""
        return 2 * (precision * recall) / (precision + recall + smooth)

In [None]:
# Image Preprocessing Utilities
class ImagePreprocessor:
    """Colour-image enhancement helpers: white balance, unsharp masking and CLAHE."""

    @staticmethod
    def white_balance(img):
        """Shift the LAB a/b channels toward neutral grey, weighted by luminance.

        Assumes an 8-bit BGR image: the 128 neutral point and the /255 luminance scaling
        match OpenCV's 8-bit LAB encoding. Returns an 8-bit BGR image.
        """
        img_lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
        avg_a = np.average(img_lab[:, :, 1])
        avg_b = np.average(img_lab[:, :, 2])

        # Compute in floating point and clip before converting back. Writing the shifted
        # values straight into the uint8 array would wrap anything outside [0, 255].
        lab = img_lab.astype(np.float64)
        luminance = lab[:, :, 0] / 255.0
        lab[:, :, 1] -= (avg_a - 128) * luminance * 1.2
        lab[:, :, 2] -= (avg_b - 128) * luminance * 1.2
        balanced = np.clip(lab, 0, 255).astype(np.uint8)
        return cv2.cvtColor(balanced, cv2.COLOR_LAB2BGR)

    @staticmethod
    def unsharp_mask(image, radius=2, amount=1):
        """Sharpen ``image`` by adding back ``amount`` times its high-frequency detail.

        Computes ``img + amount * (img - GaussianBlur(img, sigma=radius))`` and clips the
        result to [0, 1]. Integer images are first scaled to [0, 1] by their dtype maximum,
        and float images are assumed to already be in [0, 1]. Always returns float32 in
        [0, 1]; ``enhance_image`` relies on this when it multiplies the result by 255.
        """
        img = image.astype(np.float32)
        if np.issubdtype(image.dtype, np.integer):
            img /= np.iinfo(image.dtype).max
        # ksize=(0, 0) lets OpenCV derive the kernel size from sigma.
        blurred = cv2.GaussianBlur(img, (0, 0), sigmaX=radius)
        sharpened = img + amount * (img - blurred)
        return np.clip(sharpened, 0.0, 1.0)

    @staticmethod
    def apply_clahe(image):
        """Apply CLAHE to the L channel of an 8-bit BGR image and return the result as RGB."""
        lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
        l, a, b = cv2.split(lab)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        l1 = clahe.apply(l)
        lab_planes = cv2.merge((l1, a, b))
        bgr = cv2.cvtColor(lab_planes, cv2.COLOR_LAB2BGR)
        # Note the colour order changes here: the final output is RGB, not BGR.
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    @classmethod
    def enhance_image(cls, image):
        """Run white balance -> unsharp mask -> CLAHE on an 8-bit BGR image.

        Returns:
            ``(enhanced, sharpened, white_balanced)``:
            - ``enhanced`` is the RGB uint8 CLAHE output.
            - ``sharpened`` is BGR float32 in [0, 1].
            - ``white_balanced`` is BGR uint8.
        """
        white_balanced = cls.white_balance(image)
        sharpened = cls.unsharp_mask(white_balanced)
        enhanced = cls.apply_clahe((sharpened * 255).astype(np.uint8))
        return enhanced, sharpened, white_balanced

In [None]:

# Custom Dataset
class KidneyDataset(Dataset):
    """Paired grayscale image/mask dataset for kidney segmentation.

    ``filenames[idx]`` names both the image in ``image_dir`` and its mask in ``masks_dir``.
    Both are read single-channel and divided by 255 to float32. If ``transform`` is given,
    it is called as ``transform(image=..., mask=...)`` and must return a dict with
    ``'image'`` and ``'mask'`` (albumentations convention). Items are returned as
    ``(image, mask)`` tensors, with a leading channel axis added to any 2-D result.
    """

    def __init__(self, filenames, image_dir, masks_dir, transform=None):
        self.filenames = filenames
        self.image_dir = image_dir
        self.masks_dir = masks_dir
        self.transform = transform

    def __len__(self):
        return len(self.filenames)

    @staticmethod
    def _read_grayscale(path):
        """Read ``path`` as a single-channel image scaled by 1/255 to float32.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        pixels = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if pixels is None:
            # cv2.imread signals a missing/unreadable file by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image file: {path}")
        return pixels.astype(np.float32) / 255.0

    @staticmethod
    def _as_chw_tensor(array):
        """Convert ``array`` (ndarray or tensor) to a tensor, prepending a channel axis if 2-D."""
        if isinstance(array, np.ndarray):
            # Flipped NumPy views have negative strides, which torch.as_tensor rejects.
            array = np.ascontiguousarray(array)
        tensor = torch.as_tensor(array)
        return tensor.unsqueeze(0) if tensor.dim() == 2 else tensor

    def __getitem__(self, idx):
        name = self.filenames[idx]
        image = self._read_grayscale(os.path.join(self.image_dir, name))
        mask = self._read_grayscale(os.path.join(self.masks_dir, name))

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # Without a transform (or with one that returns NumPy arrays) the data is still an
        # ndarray, which has no .unsqueeze. Converting first makes every path return tensors.
        return self._as_chw_tensor(image), self._as_chw_tensor(mask)

In [None]:
class SpatialAttention(nn.Module):
    """Spatial attention gate: rescales every spatial position of the input by a weight in (0, 1).

    A 1x1 conv reduces the channels to ``max(in_channels // 2, 1)``, followed by BN + ReLU.
    A 7x7 conv then produces a single-channel map, and a sigmoid squashes it. The map is
    broadcast over all input channels, so every channel at a given pixel is scaled by the
    same factor.

    Args:
        in_channels: channel count of the input feature map.
    """
    def __init__(self, in_channels):
        super().__init__()
        # Guard against a zero-width hidden layer when in_channels == 1.
        hidden_channels = max(in_channels // 2, 1)
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=1),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, 1, kernel_size=7, padding=3),
            nn.Sigmoid()
        )

    def forward(self, x):
        # One attention value per spatial position (shape N x 1 x H x W), shared by all channels.
        attention = self.conv(x)
        return x * attention

In [None]:
class EnhancedUNet(nn.Module):
    """Four-level U-Net with a spatial-attention bottleneck that outputs sigmoid probabilities.

    Encoder: conv blocks with 64, 128, 256 and 512 channels, each followed by 2x2 max-pooling.
    Bridge: a 1024-channel conv block followed by ``SpatialAttention``.
    Decoder: at each level a 2x2 transposed conv halves the channels and doubles H and W. The
    result is concatenated with the encoder feature map of the same level (U-Net skip
    connection), and a conv block fuses them. Skip connections are concatenations; there are
    no residual (additive) connections.

    Inputs whose height or width is not a multiple of 16 are supported. After each upsampling
    the decoder map is zero-padded to the skip tensor's size before concatenation, so the
    output has the input's H and W.

    Args:
        in_channels: channels consumed by the first conv block. ``forward`` keeps only the
            first ``in_channels`` channels of inputs that carry a different channel count.
        out_channels: number of output probability maps.
    """
    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        self.in_channels = in_channels

        def conv_block(in_ch, out_ch, dropout_rate=0.2):
            """Conv3x3-BN-ReLU-Dropout2d-Conv3x3-BN-ReLU; padding=1 keeps H and W unchanged."""
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Dropout2d(dropout_rate),
                nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True)
            )

        # Encoder
        self.enc1 = conv_block(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)

        self.enc2 = conv_block(64, 128)
        self.pool2 = nn.MaxPool2d(2)

        self.enc3 = conv_block(128, 256)
        self.pool3 = nn.MaxPool2d(2)

        self.enc4 = conv_block(256, 512)
        self.pool4 = nn.MaxPool2d(2)

        # Bridge with spatial attention
        self.bridge = nn.Sequential(
            conv_block(512, 1024),
            SpatialAttention(1024)
        )

        # Decoder: each dec block sees upsampled channels + skip channels.
        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = conv_block(1024, 512)

        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = conv_block(512, 256)

        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = conv_block(256, 128)

        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = conv_block(128, 64)

        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    @staticmethod
    def _pad_to_match(upsampled, skip):
        """Zero-pad ``upsampled`` on H and W so it matches ``skip``'s spatial size.

        Max-pooling floors odd sizes, so the upsampled map can be at most one pixel smaller
        than the skip map along each axis.
        """
        diff_h = skip.size(2) - upsampled.size(2)
        diff_w = skip.size(3) - upsampled.size(3)
        if diff_h == 0 and diff_w == 0:
            return upsampled
        # F.pad order for the last two dims: (left, right, top, bottom).
        return F.pad(upsampled, [diff_w // 2, diff_w - diff_w // 2,
                                 diff_h // 2, diff_h - diff_h // 2])

    def _decode(self, upconv, dec_block, below, skip):
        """One decoder level: upsample ``below``, align it to ``skip``, concatenate, and fuse."""
        upsampled = self._pad_to_match(upconv(below), skip)
        return dec_block(torch.cat([upsampled, skip], dim=1))

    def forward(self, x):
        """Map ``x`` of shape (N, C, H, W) to probabilities of shape (N, out_channels, H, W)."""
        if x.size(1) != self.in_channels:
            # Keep only the leading channels the first conv expects.
            x = x[:, :self.in_channels, :, :]

        # Encoder
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool1(enc1))
        enc3 = self.enc3(self.pool2(enc2))
        enc4 = self.enc4(self.pool3(enc3))

        # Bridge
        bridge = self.bridge(self.pool4(enc4))

        # Decoder with skip connections
        dec4 = self._decode(self.upconv4, self.dec4, bridge, enc4)
        dec3 = self._decode(self.upconv3, self.dec3, dec4, enc3)
        dec2 = self._decode(self.upconv2, self.dec2, dec3, enc2)
        dec1 = self._decode(self.upconv1, self.dec1, dec2, enc1)

        return torch.sigmoid(self.final_conv(dec1))

In [None]:
def focal_loss(pred, target, alpha=0.25, gamma=2.0):
    """Binary focal loss (Lin et al., 2017) on probabilities, averaged over all elements.

    ``FL = -alpha_t * (1 - p_t) ** gamma * log(p_t)``, where ``p_t`` is the probability
    assigned to the true class. ``alpha_t`` is ``alpha`` for positive pixels and
    ``1 - alpha`` for negative pixels.

    Args:
        pred: probabilities in [0, 1] (``F.binary_cross_entropy`` requires this), any shape.
        target: tensor with ``pred``'s shape; hard {0, 1} labels are assumed for ``p_t``.
        alpha: positive-class weight; negatives get ``1 - alpha``. ``None`` disables class
            weighting.
        gamma: focusing parameter; larger values down-weight well-classified pixels more.

    Returns:
        Scalar mean loss.
    """
    bce = F.binary_cross_entropy(pred, target, reduction='none')
    # For hard labels exp(-BCE) equals p_t, the probability of the true class.
    p_t = torch.exp(-bce)
    loss = (1 - p_t) ** gamma * bce
    if alpha is not None:
        # Weight positives and negatives differently. A scalar alpha applied to every pixel
        # would only rescale the loss and give no class balancing.
        alpha_t = alpha * target + (1 - alpha) * (1 - target)
        loss = alpha_t * loss
    return torch.mean(loss)

In [None]:
# Metric names tracked per epoch; must match the keys TrainingLogger.log_epoch reads.
_METRIC_KEYS = ('dice', 'iou', 'precision', 'recall', 'f1')


def _segmentation_loss(predictions, masks):
    """Training objective: 0.6 * focal loss + 0.4 * soft-Dice loss on probability maps."""
    focal = focal_loss(predictions, masks)
    soft_dice = 1 - SegmentationMetrics.dice_coefficient(predictions, masks)
    return 0.6 * focal + 0.4 * soft_dice


def _threshold_metrics(predictions, masks, threshold):
    """Binarize ``predictions`` at ``threshold`` and return batch-level metrics as floats."""
    with torch.no_grad():
        pred_binary = (predictions > threshold).float()
        precision = SegmentationMetrics.precision(pred_binary, masks)
        recall = SegmentationMetrics.recall(pred_binary, masks)
        return {
            'dice': SegmentationMetrics.dice_coefficient(pred_binary, masks).item(),
            'iou': SegmentationMetrics.iou_score(pred_binary, masks).item(),
            'precision': precision.item(),
            'recall': recall.item(),
            'f1': SegmentationMetrics.f1_score(precision, recall).item(),
        }


def _run_epoch(model, loader, device, threshold, optimizer=None):
    """Run one pass over ``loader``; trains if ``optimizer`` is given, otherwise evaluates.

    Returns:
        ``(mean loss, {metric: mean value})``. Metrics are averaged over batches.

    Raises:
        ValueError: if ``loader`` yields no batches (the means would otherwise be NaN).
    """
    training = optimizer is not None
    model.train(training)
    losses = []
    metrics = {key: [] for key in _METRIC_KEYS}

    with torch.set_grad_enabled(training):
        for images, masks in loader:
            images, masks = images.to(device), masks.to(device)
            if training:
                optimizer.zero_grad()
            predictions = model(images)
            loss = _segmentation_loss(predictions, masks)
            if training:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

            losses.append(loss.item())
            for key, value in _threshold_metrics(predictions, masks, threshold).items():
                metrics[key].append(value)

    if not losses:
        raise ValueError("DataLoader produced no batches; cannot compute epoch metrics.")
    return float(np.mean(losses)), {key: float(np.mean(values)) for key, values in metrics.items()}


def train_model(model, train_loader, val_loader, config=RealTimeConfig(),
                checkpoint_path='/kaggle/working/best_kidney_segmentation_model.pth'):
    """Train ``model`` with AdamW, LR-on-plateau scheduling, checkpointing and early stopping.

    Each epoch runs one training pass and one validation pass, logs them via
    ``TrainingLogger``, and saves ``model.state_dict()`` to ``checkpoint_path`` whenever the
    validation loss improves. Training stops after ``config.PATIENCE`` epochs without
    improvement. Loss curves and metrics are exported at the end.

    Args:
        model: network mapping images to probabilities in [0, 1].
        train_loader: iterable of ``(images, masks)`` batches for training.
        val_loader: iterable of ``(images, masks)`` batches for validation.
        config: object providing DEVICE, LEARNING_RATE, NUM_EPOCHS, PATIENCE and
            (optionally) INFERENCE_THRESHOLD, the binarisation threshold for metrics.
            Defaults to 0.5 if absent.
        checkpoint_path: where the best weights are written.

    Returns:
        The model after the last epoch run; this is not necessarily the best checkpoint.
    """
    device = config.DEVICE
    threshold = getattr(config, 'INFERENCE_THRESHOLD', 0.5)
    model = model.to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.LEARNING_RATE,
        weight_decay=1e-4  # decoupled weight decay for regularization
    )
    # `verbose=` is deprecated for PyTorch LR schedulers; LR drops are reported manually below.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    logger = TrainingLogger()
    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(config.NUM_EPOCHS):
        avg_train_loss, train_metrics_avg = _run_epoch(
            model, train_loader, device, threshold, optimizer)
        avg_val_loss, val_metrics_avg = _run_epoch(model, val_loader, device, threshold)

        lrs_before = [group['lr'] for group in optimizer.param_groups]
        scheduler.step(avg_val_loss)
        lrs_after = [group['lr'] for group in optimizer.param_groups]
        if lrs_after != lrs_before:
            print(f"Learning rate reduced to {lrs_after[0]:.2e}")

        logger.log_epoch(avg_train_loss, avg_val_loss,
                         train_metrics_avg, val_metrics_avg)

        print(f"Epoch {epoch+1}/{config.NUM_EPOCHS}")
        print(f"Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
        print(f"Train Dice: {train_metrics_avg['dice']:.4f}, Val Dice: {val_metrics_avg['dice']:.4f}")

        # Checkpoint on validation-loss improvement.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            torch.save(model.state_dict(), checkpoint_path)
        else:
            patience_counter += 1

        if patience_counter >= config.PATIENCE:
            print("Early stopping triggered")
            break

    logger.plot_metrics()
    logger.save_metrics()

    return model

In [None]:
def main(image_dir='/kaggle/input/2kdataset/2kdataset/images',
         masks_dir='/kaggle/input/2kdataset/2kdataset/masks'):
    """Build the datasets and loaders, train an ``EnhancedUNet``, and return the trained model.

    Args:
        image_dir: directory of ``.png`` input images.
        masks_dir: directory of masks sharing the image filenames.

    Returns:
        The model returned by ``train_model``.

    Raises:
        FileNotFoundError: if ``image_dir`` contains no ``.png`` files.
    """
    import random

    seed = RealTimeConfig.RANDOM_SEED
    # Seed the RNGs the pipeline may draw from (weight init, dropout, augmentation).
    # Multi-worker loading and some CUDA kernels can still introduce nondeterminism.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    size = RealTimeConfig.IMAGE_SIZE
    # Random augmentation is applied to the training set only.
    train_transform = A.Compose([
        A.Resize(height=size, width=size),
        A.OneOf([
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5)
        ], p=0.7),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=45, p=0.5),
        A.RandomBrightnessContrast(p=0.3),
        A.ToFloat(max_value=1.0),
        ToTensorV2()
    ])
    # Validation must be deterministic. Its loss drives LR scheduling, checkpointing and early
    # stopping, and random augmentation would make it noisy.
    val_transform = A.Compose([
        A.Resize(height=size, width=size),
        A.ToFloat(max_value=1.0),
        ToTensorV2()
    ])

    # os.listdir order is filesystem-dependent; sorting keeps the seeded split reproducible.
    filenames = sorted(f for f in os.listdir(image_dir) if f.endswith('.png'))
    if not filenames:
        raise FileNotFoundError(f"No .png images found in {image_dir}")
    train_files, val_files = train_test_split(filenames, test_size=0.2, random_state=seed)

    train_dataset = KidneyDataset(train_files, image_dir, masks_dir, transform=train_transform)
    val_dataset = KidneyDataset(val_files, image_dir, masks_dir, transform=val_transform)

    loader_kwargs = dict(batch_size=RealTimeConfig.BATCH_SIZE, num_workers=4, pin_memory=True)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    model = EnhancedUNet()
    trained_model = train_model(model, train_loader, val_loader)

    print("Model training completed successfully!")
    return trained_model

In [None]:

# Jupyter kernels also execute cells with __name__ == "__main__", so running this cell starts training.
if "__main__" == __name__:
    main()