In [None]:
!pip install torch
!pip install torchvision
!pip install opencv-python
!pip install albumentations
!pip install numpy
!pip install scikit-learn
!pip install matplotlib
!pip install pandas

In [None]:
import kagglehub

# Download the latest version of the "kidneyseg" PyTorch model from Kaggle and
# keep whatever `model_download` returns in `path` (printed below; presumably
# the local location of the downloaded files — TODO confirm).
# NOTE(review): `path` is not referenced by the inference cell further down,
# which uses a hard-coded MODEL_PATH instead — confirm which one is intended.
path = kagglehub.model_download("karthickjeeva/kidneyseg/pyTorch/default")

print("Path to model files:", path)

In [None]:
# Kidney Segmentation - Inference Notebook
# This notebook provides inference capabilities for the trained kidney segmentation model

# ============================================================================
# SECTION 1: Import Required Libraries
# ============================================================================

import os
import datetime
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm

# Report the installed PyTorch version alongside the results.
print(f"PyTorch version: {torch.__version__}")

# Device configuration: use the GPU when CUDA is available, otherwise the CPU.
# NOTE(review): KidneyEvaluator.__init__ picks its own device and does not read
# this module-level `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ============================================================================
# SECTION 2: Model Architecture Definitions
# ============================================================================

class ResidualConnection(nn.Module):
    """Skip-connection adder that projects the skip tensor when channel counts differ.

    When ``in_channels != out_channels`` a 1x1 convolution maps the skip tensor
    to ``out_channels`` before the addition; otherwise it is added unchanged.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        if in_channels == out_channels:
            # Channel counts already agree, so no projection layer is created.
            self.conv = None
        else:
            self.conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x, identity):
        """Return ``x`` plus the (optionally projected) ``identity`` tensor."""
        shortcut = identity if self.conv is None else self.conv(identity)
        return x + shortcut


class ResidualBlock(nn.Module):
    """Two 3x3 Conv -> BatchNorm -> ReLU stages whose output is added to the block input.

    The input is routed through ``ResidualConnection`` so it is projected to
    ``out_channels`` when the channel counts differ.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Layers are created in the order conv, bn, relu, conv, bn, relu so the
        # Sequential indices (0-5) stay the same as before.
        layers = []
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(stage_in, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.conv_block = nn.Sequential(*layers)

        self.residual = ResidualConnection(in_channels, out_channels)

    def forward(self, x):
        """Run the conv stack on ``x`` and add the skip path built from ``x``."""
        features = self.conv_block(x)
        return self.residual(features, x)


class EnhancedUNet(nn.Module):
    """U-Net built from ResidualBlocks with deep-supervision heads.

    Five encoder stages (64 to 1024 channels) are mirrored by four decoder
    stages. Each decoder stage up-samples with a transposed convolution,
    concatenates the matching encoder output, applies a ResidualBlock and a
    LayerNorm. 1x1 heads on decoder stages 4, 3 and 2 provide the auxiliary
    (deep-supervision) outputs.

    The LayerNorm shapes are hard-coded, with the full-resolution stage
    normalising over ``[64, 256, 256]``, so the spatial input size is fixed.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()

        # --- Encoder path -------------------------------------------------
        self.encoder1 = ResidualBlock(in_channels, 64)
        self.encoder2 = ResidualBlock(64, 128)
        self.encoder3 = ResidualBlock(128, 256)
        self.encoder4 = ResidualBlock(256, 512)
        self.encoder5 = ResidualBlock(512, 1024)

        # --- Deep-supervision heads (1x1 convs on decoder features) --------
        self.deep_sup4 = nn.Conv2d(512, out_channels, 1)
        self.deep_sup3 = nn.Conv2d(256, out_channels, 1)
        self.deep_sup2 = nn.Conv2d(128, out_channels, 1)

        # --- Decoder path: upsample -> ResidualBlock on concat -> LayerNorm -
        self.upconv4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.decoder4 = ResidualBlock(1024, 512)
        self.norm4 = nn.LayerNorm([512, 32, 32])

        self.upconv3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.decoder3 = ResidualBlock(512, 256)
        self.norm3 = nn.LayerNorm([256, 64, 64])

        self.upconv2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.decoder2 = ResidualBlock(256, 128)
        self.norm2 = nn.LayerNorm([128, 128, 128])

        self.upconv1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.decoder1 = ResidualBlock(128, 64)
        self.norm1 = nn.LayerNorm([64, 256, 256])

        # --- Output head and shared dropout --------------------------------
        self.final_conv = nn.Conv2d(64, out_channels, 1)
        self.dropout = nn.Dropout2d(0.3)

    def forward(self, x):
        """Return the segmentation logits.

        In training mode the result is the tuple
        ``(final_out, deep_out4, deep_out3, deep_out2)``; otherwise only
        ``final_out`` is returned.
        """
        # Encoder: every stage after the first receives a 2x max-pooled copy of
        # the previous stage's output; dropout is applied to each stage output.
        encoder_stages = (
            self.encoder1,
            self.encoder2,
            self.encoder3,
            self.encoder4,
            self.encoder5,
        )
        skips = []
        features = x
        for depth, stage in enumerate(encoder_stages):
            stage_input = features if depth == 0 else F.max_pool2d(features, 2)
            features = self.dropout(stage(stage_input))
            skips.append(features)

        # Decoder: walk back up, consuming the stored encoder outputs from the
        # deepest to the shallowest.
        decoder_stages = (
            (self.upconv4, self.decoder4, self.norm4, self.deep_sup4),
            (self.upconv3, self.decoder3, self.norm3, self.deep_sup3),
            (self.upconv2, self.decoder2, self.norm2, self.deep_sup2),
            (self.upconv1, self.decoder1, self.norm1, None),
        )
        features = skips.pop()  # output of encoder5
        aux_outputs = []
        for upconv, decoder, norm, head in decoder_stages:
            merged = torch.cat((upconv(features), skips.pop()), dim=1)
            features = norm(decoder(merged))
            if head is not None:
                aux_outputs.append(head(features))

        final_out = self.final_conv(features)

        if self.training:
            return (final_out, *aux_outputs)
        return final_out


# ============================================================================
# SECTION 3: Inference Class
# ============================================================================

class KidneyEvaluator:
    """Kidney Segmentation Inference Class

    Loads an ``EnhancedUNet`` checkpoint once and offers single-image and batch
    inference. Every processed image produces a six-panel Matplotlib figure
    (original, mask, confidence map, overlay, contours, statistics).

    Class attributes:
        INPUT_SIZE: ``(width, height)`` every image is resized to before it is
            fed to the model. ``EnhancedUNet`` uses LayerNorm layers with fixed
            shapes for 256x256 inputs, so this must stay in sync with the model.
        THRESHOLD: probability above which a pixel is labelled as foreground.
    """

    INPUT_SIZE = (256, 256)
    THRESHOLD = 0.5
    # File extensions that are passed through unchanged as the output format of
    # `savefig`; anything else (e.g. ".bmp") is saved as PNG instead.
    _SAVEFIG_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tif', '.tiff')

    def __init__(self, model_path, device='cuda'):
        """Load the checkpoint at ``model_path`` and switch the model to eval mode.

        Args:
            model_path: Path to the checkpoint file.
            device: Requested device; ignored in favour of ``'cpu'`` when CUDA
                is not available.
        """
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.model = self.load_model(model_path)
        self.model.eval()
        print(f"Model loaded successfully on {self.device}")

    def load_model(self, model_path):
        """Load trained model from checkpoint

        Accepts either a checkpoint dict holding the weights under
        ``'model_state_dict'`` or a bare state dict.

        Returns:
            The ``EnhancedUNet`` instance with weights loaded, on ``self.device``.
        """
        # NOTE(review): torch.load is pickle-based; depending on the installed
        # PyTorch version's `weights_only` default it may execute code embedded
        # in the file. Only load checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location=self.device)
        model = EnhancedUNet(in_channels=1, out_channels=1)

        # Count parameters
        total_params = sum(p.numel() for p in model.parameters())
        print(f"Total number of parameters: {total_params:,}")

        # Load state dict (wrapped in a training checkpoint, or stored bare)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)

        return model.to(self.device)

    def preprocess_image(self, image):
        """Preprocess image for model input

        Args:
            image: ``(H, W)`` grayscale array or ``(H, W, C)`` array with 1, 3
                (BGR) or 4 (BGRA) channels. Values are divided by 255, so
                8-bit data is assumed.

        Returns:
            Float tensor of shape ``(1, 1, 256, 256)`` on ``self.device``.
        """
        # Convert to grayscale if needed
        if image.ndim == 3:
            channels = image.shape[2]
            if channels == 1:
                image = np.ascontiguousarray(image[:, :, 0])
            elif channels == 4:
                image = cv2.cvtColor(image, cv2.COLOR_BGRA2GRAY)
            else:
                image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        # Resize to model input size
        image = cv2.resize(image, self.INPUT_SIZE)

        # Normalize and convert to tensor (adds batch and channel dimensions)
        image = image.astype(np.float32) / 255.0
        tensor = torch.from_numpy(image).unsqueeze(0).unsqueeze(0)

        return tensor.to(self.device)

    @staticmethod
    def _result_filename(img_name):
        """Return the file name used for the saved figure of ``img_name``.

        The input name is kept when its extension is one ``savefig`` is known
        to write; otherwise ``.png`` is appended so saving does not fail on an
        unsupported format such as ``.bmp``.
        """
        extension = os.path.splitext(img_name)[1].lower()
        if extension in KidneyEvaluator._SAVEFIG_EXTENSIONS:
            return f'result_{img_name}'
        return f'result_{img_name}.png'

    def _render_figure(self, original, pred_mask, conf_map, contours,
                       confidence_score, inference_time):
        """Draw the 2x3 summary figure and return it.

        The figure is closed again if drawing raises, so a failure does not
        leave an open figure behind.
        """
        original_height, original_width = original.shape[:2]
        rgb = cv2.cvtColor(original, cv2.COLOR_BGR2RGB)
        mask_bool = pred_mask > 0.5  # pred_mask holds 0.0 / 1.0 values

        fig, axes = plt.subplots(2, 3, figsize=(15, 10))
        try:
            fig.suptitle(f'Inference Time: {inference_time:.4f}s | Confidence: {confidence_score:.3f}',
                         fontsize=14, fontweight='bold')

            # Original
            axes[0, 0].imshow(rgb)
            axes[0, 0].set_title('Original Image', fontsize=12)
            axes[0, 0].axis('off')

            # Predicted Mask
            axes[0, 1].imshow(pred_mask, cmap='gray')
            axes[0, 1].set_title('Predicted Mask', fontsize=12)
            axes[0, 1].axis('off')

            # Confidence Map
            conf_map_display = axes[0, 2].imshow(conf_map, cmap='jet', vmin=0, vmax=1)
            axes[0, 2].set_title('Confidence Map', fontsize=12)
            axes[0, 2].axis('off')
            fig.colorbar(conf_map_display, ax=axes[0, 2], fraction=0.046, pad=0.04)

            # Overlay
            mask_overlay = np.zeros_like(rgb)
            mask_overlay[mask_bool] = [255, 255, 0]  # Yellow overlay
            overlay = cv2.addWeighted(rgb.copy(), 0.85, mask_overlay, 0.15, 0)
            axes[1, 0].imshow(overlay)
            axes[1, 0].set_title('Mask Overlay', fontsize=12)
            axes[1, 0].axis('off')

            # Contour visualization (drawContours draws in place, hence the copy)
            contour_img = rgb.copy()
            cv2.drawContours(contour_img, contours, -1, (255, 255, 0), 2)
            axes[1, 1].imshow(contour_img)
            axes[1, 1].set_title('Contour Visualization', fontsize=12)
            axes[1, 1].axis('off')

            # Statistics
            axes[1, 2].axis('off')
            stats_text = f"""
        Image Statistics:
        ─────────────────
        Original Size: {original_width}x{original_height}
        Model Input: 256x256
        
        Segmentation Metrics:
        ─────────────────
        Mean Confidence: {confidence_score:.3f}
        Max Confidence: {conf_map.max():.3f}
        Min Confidence: {conf_map.min():.3f}
        
        Mask Coverage: {mask_bool.sum() / mask_bool.size * 100:.2f}%
        Number of Contours: {len(contours)}
        """
            axes[1, 2].text(0.1, 0.5, stats_text, fontsize=10,
                            verticalalignment='center', family='monospace')

            fig.tight_layout()
        except Exception:
            plt.close(fig)
            raise

        return fig

    def process_single_image(self, image_path, save_path=None, show_plot=True):
        """Process a single image and return segmentation results

        Args:
            image_path: Path of the image to segment.
            save_path: Optional path the summary figure is written to.
            show_plot: Show the figure when True, otherwise close it.

        Returns:
            Dict with keys ``'original'`` (BGR image as read), ``'pred_mask'``
            (0/1 float mask at the original resolution), ``'conf_map'``
            (sigmoid probabilities at the original resolution),
            ``'confidence_score'`` (mean of ``conf_map`` over all pixels),
            ``'inference_time'`` (seconds, preprocessing included),
            ``'contours'`` and ``'figure'``.

        Raises:
            FileNotFoundError: ``image_path`` does not exist.
            ValueError: the file exists but could not be decoded.
        """
        # Read image
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image not found: {image_path}")

        original = cv2.imread(image_path)
        if original is None:
            raise ValueError(f"Failed to read image: {image_path}")

        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise ValueError(f"Failed to read image: {image_path}")

        # Store original dimensions
        original_height, original_width = original.shape[:2]

        # Preprocess + inference (timed together)
        start_time = datetime.datetime.now()
        processed = self.preprocess_image(image)

        with torch.no_grad():
            output = self.model(processed)
            probabilities = torch.sigmoid(output)
            prediction = (probabilities > self.THRESHOLD).float()

        # CUDA kernels run asynchronously; wait for them so the measured time
        # covers the actual computation and not just the kernel launches.
        if self.device.type == 'cuda':
            torch.cuda.synchronize(self.device)

        end_time = datetime.datetime.now()
        inference_time = (end_time - start_time).total_seconds()

        # Convert predictions back to numpy
        pred_mask = prediction.squeeze().cpu().numpy()
        conf_map = probabilities.squeeze().cpu().numpy()

        # Resize predictions to match original image dimensions. The resized
        # mask is thresholded again so it stays strictly binary instead of
        # carrying interpolated values along the object borders.
        target_size = (original_width, original_height)
        pred_mask = (cv2.resize(pred_mask, target_size) > 0.5).astype(np.float32)
        conf_map = cv2.resize(conf_map, target_size)

        # Mean probability over the whole image (background pixels included).
        confidence_score = float(conf_map.mean())

        contours, _ = cv2.findContours((pred_mask > 0.5).astype(np.uint8),
                                       cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        # Create visualization
        fig = self._render_figure(original, pred_mask, conf_map, contours,
                                  confidence_score, inference_time)

        # Save figure if save_path is provided; never leave the figure open
        # when saving fails (matters when many images are processed in a loop).
        try:
            if save_path:
                os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
                fig.savefig(save_path, dpi=300, bbox_inches='tight')
                print(f"Result saved to: {save_path}")
        except Exception:
            plt.close(fig)
            raise

        if show_plot:
            plt.show()
        else:
            plt.close(fig)

        return {
            'original': original,
            'pred_mask': pred_mask,
            'conf_map': conf_map,
            'confidence_score': confidence_score,
            'inference_time': inference_time,
            'contours': contours,
            'figure': fig
        }

    def process_batch(self, image_paths, output_dir='./results'):
        """Process multiple images and save results

        A figure per image and a ``batch_summary.txt`` are written to
        ``output_dir``. Images that fail are reported on stdout and skipped.

        Args:
            image_paths: Sequence of image paths.
            output_dir: Directory for the figures and the summary file.

        Returns:
            List of dicts (``'image'``, ``'confidence'``, ``'inference_time'``,
            ``'num_contours'``), one per successfully processed image.
        """
        os.makedirs(output_dir, exist_ok=True)
        results = []

        print(f"\nProcessing {len(image_paths)} images...")

        for img_path in tqdm(image_paths):
            try:
                img_name = os.path.basename(img_path)
                save_path = os.path.join(output_dir, self._result_filename(img_name))

                result = self.process_single_image(img_path, save_path=save_path, show_plot=False)

                results.append({
                    'image': img_name,
                    'confidence': result['confidence_score'],
                    'inference_time': result['inference_time'],
                    'num_contours': len(result['contours'])
                })

            except Exception as e:
                # Best effort: one unreadable image must not abort the batch.
                print(f"\nError processing {img_path}: {str(e)}")
                continue

        # Save summary
        summary_path = os.path.join(output_dir, 'batch_summary.txt')
        with open(summary_path, 'w') as f:
            f.write('Kidney Segmentation - Batch Processing Summary\n')
            f.write('=' * 60 + '\n\n')

            for result in results:
                f.write(f"Image: {result['image']}\n")
                f.write(f"  Confidence: {result['confidence']:.3f}\n")
                f.write(f"  Inference Time: {result['inference_time']:.4f}s\n")
                f.write(f"  Number of Contours: {result['num_contours']}\n\n")

            # Overall statistics
            if results:
                avg_confidence = np.mean([r['confidence'] for r in results])
                avg_time = np.mean([r['inference_time'] for r in results])

                f.write('\n' + '=' * 60 + '\n')
                f.write('Overall Statistics:\n')
                f.write(f"  Total Images Processed: {len(results)}\n")
                f.write(f"  Average Confidence: {avg_confidence:.3f}\n")
                f.write(f"  Average Inference Time: {avg_time:.4f}s\n")

        print("\nBatch processing complete!")
        print(f"Results saved to: {output_dir}")
        print(f"Summary saved to: {summary_path}")

        return results


# ============================================================================
# SECTION 4: Usage Examples
# ============================================================================

def inference_single_image(model_path, image_path, save_path=None):
    """
    Run the segmentation model on one image.

    Builds a fresh KidneyEvaluator from the checkpoint and forwards the image
    to its process_single_image method.

    Args:
        model_path: Path to the trained model checkpoint
        image_path: Path to the input image
        save_path: Optional path to save the result

    Returns:
        Whatever KidneyEvaluator.process_single_image returns for the image.
    """
    return KidneyEvaluator(model_path).process_single_image(image_path, save_path=save_path)


def inference_batch(model_path, image_dir, output_dir='./results'):
    """
    Batch inference function for multiple images

    Collects the regular files in ``image_dir`` whose extension marks them as
    images (case-insensitive), in sorted order so repeated runs process them in
    the same sequence, and hands them to ``KidneyEvaluator.process_batch``.

    Args:
        model_path: Path to the trained model checkpoint
        image_dir: Directory containing input images
        output_dir: Directory to save results

    Returns:
        The list returned by ``KidneyEvaluator.process_batch``; an empty list
        when the directory contains no valid images (the model is not loaded
        in that case).
    """
    valid_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'}

    # Get all image files; directories that merely carry an image-like suffix
    # are skipped.
    image_paths = []
    for name in sorted(os.listdir(image_dir)):
        candidate = os.path.join(image_dir, name)
        if os.path.splitext(name)[1].lower() in valid_extensions and os.path.isfile(candidate):
            image_paths.append(candidate)

    if not image_paths:
        print(f"No valid images found in {image_dir}")
        return []

    evaluator = KidneyEvaluator(model_path)
    return evaluator.process_batch(image_paths, output_dir)


# ============================================================================
# SECTION 5: Main Execution
# ============================================================================

if __name__ == "__main__":
    # Horizontal rule used for the section banners printed below.
    _RULE = "=" * 60

    # Configuration: checkpoint, input image and output figure locations
    MODEL_PATH = '/kaggle/input/kidneyseg/pytorch/default/20/best_model_checkpoint_1MW.pth'
    IMAGE_PATH = '/kaggle/input/testimages/240_F_603611028_GGVVOAxWeuRbSFgkZVrcJIkMSLJfeTDG.jpg'
    SAVE_PATH = '/kaggle/working/kidney_result.png'

    # Example 1: Single image inference
    print("\n" + _RULE)
    print("EXAMPLE 1: Single Image Inference")
    print(_RULE)

    result = inference_single_image(MODEL_PATH, IMAGE_PATH, SAVE_PATH)
    print("\nInference completed successfully!")
    print(f"Confidence Score: {result['confidence_score']:.3f}")
    print(f"Inference Time: {result['inference_time']:.4f}s")

    # Example 2: Batch inference (uncomment to use)
    # print("\n" + "="*60)
    # print("EXAMPLE 2: Batch Image Inference")
    # print("="*60)
    #
    # IMAGE_DIR = '/kaggle/input/testimages/'
    # OUTPUT_DIR = '/kaggle/working/batch_results/'
    #
    # results = inference_batch(MODEL_PATH, IMAGE_DIR, OUTPUT_DIR)

    print("\n" + _RULE)
    print("Inference notebook execution completed!")
    print(_RULE)