# Segmentation

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import zipfile
import requests
from io import BytesIO
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch.nn.functional as F
from PIL import Image
import shutil

# Download and extract segmentation dataset
def download_and_extract_segmentation_dataset():
    """Download the MFSD archive and copy its images and masks into ``segmentation_data/``.

    The GitHub zip of the MFSD repository is fetched into memory, extracted
    into the current working directory, and the files found under
    ``MFSD-main/dataset/images`` and ``MFSD-main/dataset/masks`` are copied to
    ``segmentation_data/images`` and ``segmentation_data/masks``.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If connecting or reading stalls past the timeout.
        zipfile.BadZipFile: If the downloaded payload is not a valid zip file.
    """
    url = "https://github.com/sadjadrz/MFSD/archive/refs/heads/main.zip"
    print("Downloading face mask segmentation dataset...")

    # (connect, read) timeouts: a stalled connection must not hang the notebook.
    response = requests.get(url, timeout=(10, 300))
    # Fail here with a clear HTTP error instead of passing an HTML error page
    # to ZipFile and getting an unrelated BadZipFile.
    response.raise_for_status()

    with zipfile.ZipFile(BytesIO(response.content)) as zip_ref:
        zip_ref.extractall("./")

    # Location of the dataset inside the extracted archive
    source_dir = "./MFSD-main/dataset"

    # Images and masks are handled identically: create the target directory
    # and copy every entry across under the same filename.
    for subdir in ("images", "masks"):
        src_dir = os.path.join(source_dir, subdir)
        dst_dir = os.path.join("segmentation_data", subdir)
        os.makedirs(dst_dir, exist_ok=True)
        for filename in os.listdir(src_dir):
            shutil.copy(
                os.path.join(src_dir, filename),
                os.path.join(dst_dir, filename),
            )

    print("Segmentation dataset downloaded and extracted successfully.")

# Fetch the dataset now so that segmentation_data/images and
# segmentation_data/masks exist before they are loaded below.
download_and_extract_segmentation_dataset()

# Task 3: Traditional Segmentation Methods

def load_segmentation_dataset(images_dir, masks_dir, limit=None):
    """Load image/mask pairs as 256x256 NumPy arrays.

    Args:
        images_dir: Directory containing the input images.
        masks_dir: Directory containing the masks. Each mask is looked up
            under the same filename as its image.
        limit: Optional maximum number of directory entries to consider.
            ``None`` means no limit.

    Returns:
        A tuple ``(images, masks, filenames)``: ``images`` holds the RGB
        images resized to 256x256, ``masks`` the binary (0/255) masks resized
        to 256x256, and ``filenames`` the names of the files actually loaded,
        in the same order.
    """
    # os.listdir returns entries in arbitrary order; sorting makes `limit`
    # select the same subset on every run and machine.
    image_files = sorted(os.listdir(images_dir))
    if limit is not None:
        image_files = image_files[:limit]

    images = []
    masks = []
    filenames = []

    for filename in tqdm(image_files, desc="Loading segmentation dataset"):
        img = cv2.imread(os.path.join(images_dir, filename))
        mask = cv2.imread(os.path.join(masks_dir, filename), cv2.IMREAD_GRAYSCALE)

        # cv2.imread signals failure by returning None, not by raising, so a
        # stray non-image file or a missing mask must be filtered here.
        if img is None or mask is None:
            print(f"Skipping {filename}: could not read image or matching mask")
            continue

        # OpenCV loads BGR; the rest of the pipeline works in RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Binarize the mask to {0, 255}
        _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)

        # Resize for consistency. Nearest-neighbour keeps the mask strictly
        # binary; the default bilinear interpolation would blend edge pixels
        # into intermediate grey values.
        img = cv2.resize(img, (256, 256))
        mask = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_NEAREST)

        images.append(img)
        masks.append(mask)
        filenames.append(filename)

    return np.array(images), np.array(masks), filenames

# Read a subset of the extracted dataset into memory
images, ground_truth_masks, filenames = load_segmentation_dataset(
    images_dir="segmentation_data/images",
    masks_dir="segmentation_data/masks",
    limit=100,  # keep the sample small for faster processing during development
)

# Display some sample images and their masks
def display_samples(images, masks, num_samples=3):
    """Show image, mask and overlay side by side for the first few samples.

    Args:
        images: Indexable collection of RGB image arrays (the overlay step
            assigns an RGB triple per pixel, so each image must be H x W x 3).
        masks: Indexable collection of single-channel masks aligned with
            ``images``; pixels greater than 0 are treated as mask.
        num_samples: Maximum number of rows to draw. Clamped to the number of
            available image/mask pairs.
    """
    # Never index past what was actually loaded (e.g. a tiny or empty dataset).
    num_samples = min(num_samples, len(images), len(masks))
    if num_samples <= 0:
        return

    plt.figure(figsize=(15, 5 * num_samples))

    for i in range(num_samples):
        image = images[i]
        mask = masks[i]

        # Original image
        plt.subplot(num_samples, 3, i * 3 + 1)
        plt.imshow(image)
        plt.title(f"Original Image {i+1}")
        plt.axis('off')

        # Ground truth mask
        plt.subplot(num_samples, 3, i * 3 + 2)
        plt.imshow(mask, cmap='gray')
        plt.title(f"Ground Truth Mask {i+1}")
        plt.axis('off')

        # Overlay: paint masked pixels red on a copy so the input is untouched
        plt.subplot(num_samples, 3, i * 3 + 3)
        overlay = image.copy()
        overlay[mask > 0] = (255, 0, 0)
        plt.imshow(overlay)
        plt.title(f"Overlay {i+1}")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Preview the first few loaded images next to their ground-truth masks
display_samples(images=images, masks=ground_truth_masks)

# Implement traditional segmentation methods

def color_thresholding(image):
    """Return a binary mask of pixels whose HSV colour is in the blue or white range.

    The RGB input is converted to HSV, thresholded against two fixed colour
    ranges, and the union is cleaned with a 5x5 morphological open then close.
    """
    hsv_image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)

    # (lower, upper) HSV bounds for the colours typical of face masks
    hsv_ranges = (
        (np.array([90, 50, 50]), np.array([130, 255, 255])),  # blue
        (np.array([0, 0, 180]), np.array([180, 30, 255])),    # white
    )
    blue_hits, white_hits = (
        cv2.inRange(hsv_image, lower, upper) for lower, upper in hsv_ranges
    )

    # A pixel belongs to the result if it matches either colour range
    result = cv2.bitwise_or(blue_hits, white_hits)

    # Open removes small speckles, close then fills small holes
    structuring_element = np.ones((5, 5), np.uint8)
    for operation in (cv2.MORPH_OPEN, cv2.MORPH_CLOSE):
        result = cv2.morphologyEx(result, operation, structuring_element)

    return result

def edge_based_segmentation(image):
    """Return a binary mask built from filled external contours of Canny edges.

    Pipeline: grayscale -> 5x5 Gaussian blur -> Canny(50, 150) -> 5x5 dilation
    -> external contours -> fill contours with area above 500 -> 5x5 close.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    smoothed = cv2.GaussianBlur(gray, (5, 5), 0)
    edge_map = cv2.Canny(smoothed, 50, 150)

    # Thicken the edges so nearby fragments join into closed outlines
    structuring_element = np.ones((5, 5), np.uint8)
    thick_edges = cv2.dilate(edge_map, structuring_element, iterations=1)

    outlines, _ = cv2.findContours(
        thick_edges.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )

    # Keep only outlines large enough to be worth filling
    large_outlines = [outline for outline in outlines if cv2.contourArea(outline) > 500]

    filled = np.zeros_like(gray)
    for outline in large_outlines:
        cv2.drawContours(filled, [outline], -1, 255, thickness=cv2.FILLED)

    # Close small gaps left inside the filled regions
    return cv2.morphologyEx(filled, cv2.MORPH_CLOSE, structuring_element)

def region_growing(image, seed_points=None, threshold=10):
    """Segment an image by growing 8-connected regions from seed pixels.

    Starting from each seed, a breadth-first search adds every 8-connected
    neighbour whose grayscale intensity differs from the *seed's* intensity by
    less than ``threshold``. The union of all grown regions is closed with a
    5x5 kernel and returned.

    Args:
        image: RGB image array.
        seed_points: Iterable of ``(x, y)`` pixel coordinates. Defaults to the
            single pixel at the image centre.
        threshold: Exclusive upper bound on the absolute intensity difference
            to the seed pixel for a neighbour to join the region.

    Returns:
        A uint8 mask with the same height and width as ``image``, 255 inside
        the grown regions and 0 elsewhere.

    Raises:
        IndexError: If a seed point lies outside the image.
    """
    # deque gives O(1) pops from the left; list.pop(0) is O(n) and made the
    # search quadratic in the region size.
    from collections import deque

    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    height, width = gray.shape

    if seed_points is None:
        seed_points = [(width // 2, height // 2)]

    mask = np.zeros_like(gray)

    # Offsets of the 8-connected neighbourhood
    neighbor_offsets = (
        (1, 0), (-1, 0), (0, 1), (0, -1),
        (1, 1), (-1, -1), (1, -1), (-1, 1),
    )

    for seed_x, seed_y in seed_points:
        # Negative indices would silently wrap around in NumPy, so reject
        # every out-of-range seed explicitly.
        if not (0 <= seed_x < width and 0 <= seed_y < height):
            raise IndexError(
                f"seed point ({seed_x}, {seed_y}) is outside the {width}x{height} image"
            )

        # Python int avoids uint8 wrap-around when taking differences
        seed_value = int(gray[seed_y, seed_x])

        queue = deque([(seed_x, seed_y)])
        processed = {(seed_x, seed_y)}

        while queue:
            curr_x, curr_y = queue.popleft()
            mask[curr_y, curr_x] = 255

            for dx, dy in neighbor_offsets:
                nx, ny = curr_x + dx, curr_y + dy
                if not (0 <= nx < width and 0 <= ny < height):
                    continue
                if (nx, ny) in processed:
                    continue
                if abs(int(gray[ny, nx]) - seed_value) < threshold:
                    queue.append((nx, ny))
                    processed.add((nx, ny))

    # Apply morphological closing to fill small holes in the grown region
    kernel = np.ones((5, 5), np.uint8)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

    return mask

def watershed_segmentation(image):
    """Return a binary mask obtained with marker-based watershed.

    Markers come from an inverted Otsu threshold: the dilated opening gives the
    sure background, the thresholded distance transform gives the sure
    foreground, and the band between them is left for watershed to decide.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    # Inverted Otsu threshold followed by an opening to remove noise
    _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    small_kernel = np.ones((3, 3), np.uint8)
    cleaned = cv2.morphologyEx(binary, cv2.MORPH_OPEN, small_kernel, iterations=2)

    # Everything outside the dilated shape is certainly background
    background = cv2.dilate(cleaned, small_kernel, iterations=3)

    # Pixels far from any boundary are certainly foreground
    distances = cv2.distanceTransform(cleaned, cv2.DIST_L2, 5)
    _, foreground = cv2.threshold(
        distances, 0.7 * distances.max(), 255, cv2.THRESH_BINARY
    )
    foreground = foreground.astype(np.uint8)

    # The band between sure background and sure foreground is undecided
    undecided = cv2.subtract(background, foreground)

    # Label the foreground blobs, shift labels so background is 1 rather than
    # 0, then reserve 0 for the undecided band that watershed will fill in
    _, labels = cv2.connectedComponents(foreground)
    labels = labels + 1
    labels[undecided == 255] = 0

    labels = cv2.watershed(image, labels)

    # Labels above 1 are the grown foreground regions; label 1 (background)
    # and watershed's boundary label stay 0
    result = np.zeros_like(gray)
    result[labels > 1] = 255

    return cv2.morphologyEx(result, cv2.MORPH_CLOSE, small_kernel)

# Apply traditional segmentation methods to dataset
def apply_segmentation_methods(images):
    """Run each traditional segmenter on every image.

    Returns a dict mapping the method's display name to an array of predicted
    masks, one per input image and in input order.
    """
    segmenters = {
        'Color Thresholding': color_thresholding,
        'Edge-based': edge_based_segmentation,
        'Watershed': watershed_segmentation,
    }
    collected = {label: [] for label in segmenters}

    for picture in tqdm(images, desc="Applying segmentation methods"):
        for label, segment in segmenters.items():
            collected[label].append(segment(picture))

    return {label: np.array(outputs) for label, outputs in collected.items()}

# Calculate evaluation metrics
def calculate_metrics(pred_mask, gt_mask):
    """Compute IoU and Dice score between a predicted and a ground-truth mask.

    Both masks are binarized with ``> 0`` before comparison.

    Args:
        pred_mask: Predicted mask array; non-zero pixels count as foreground.
        gt_mask: Ground-truth mask array of the same shape.

    Returns:
        ``(iou, dice)`` as Python floats in ``[0, 1]``. When both masks are
        empty the prediction matches the ground truth exactly, so both scores
        are 1.0.
    """
    pred = np.asarray(pred_mask) > 0
    gt = np.asarray(gt_mask) > 0

    pred_area = int(pred.sum())
    gt_area = int(gt.sum())

    # Two empty masks agree perfectly; scoring them 0 would penalize a correct
    # "nothing here" prediction and drag down the averages.
    if pred_area == 0 and gt_area == 0:
        return 1.0, 1.0

    intersection = int(np.logical_and(pred, gt).sum())
    union = int(np.logical_or(pred, gt).sum())

    # At least one mask is non-empty here, so neither denominator is zero.
    iou = intersection / union
    dice = (2 * intersection) / (pred_area + gt_area)

    return iou, dice

# Evaluate traditional segmentation methods
def evaluate_segmentation_methods(segmented_masks, ground_truth_masks):
    """Average IoU and Dice per segmentation method and print a summary line.

    ``segmented_masks`` maps a method name to its predicted masks; prediction
    ``i`` is scored against ``ground_truth_masks[i]``. Returns a dict mapping
    each method name to ``{'IoU': mean_iou, 'Dice': mean_dice}``.
    """
    results = {}

    for method_name, pred_masks in segmented_masks.items():
        # One (iou, dice) pair per prediction
        per_image = [
            calculate_metrics(prediction, ground_truth_masks[position])
            for position, prediction in enumerate(pred_masks)
        ]
        ious = [pair[0] for pair in per_image]
        dice_scores = [pair[1] for pair in per_image]

        mean_iou = np.mean(ious)
        mean_dice = np.mean(dice_scores)
        results[method_name] = {'IoU': mean_iou, 'Dice': mean_dice}

        print(f"{method_name} - Average IoU: {mean_iou:.4f}, Average Dice: {mean_dice:.4f}")

    return results

# Display segmentation results
def display_segmentation_results(images, ground_truth_masks, segmented_masks, indices, metrics):
    """Plot per-image segmentation results and a bar chart of average scores.

    The first figure has one row per entry of ``indices``: the original image,
    its ground-truth mask, then one panel per method titled with that image's
    IoU and Dice. The second figure shows the average IoU and Dice taken from
    ``metrics`` as two bar charts.
    """
    method_names = list(segmented_masks.keys())
    n_cols = len(method_names) + 2  # original + ground truth + one per method
    n_rows = len(indices)

    plt.figure(figsize=(15, 5 * n_rows))

    for row, image_idx in enumerate(indices):
        # 1-based subplot index of the first cell in this row
        first_cell = row * n_cols + 1
        truth = ground_truth_masks[image_idx]

        plt.subplot(n_rows, n_cols, first_cell)
        plt.imshow(images[image_idx])
        plt.title(f"Original {image_idx}")
        plt.axis('off')

        plt.subplot(n_rows, n_cols, first_cell + 1)
        plt.imshow(truth, cmap='gray')
        plt.title("Ground Truth")
        plt.axis('off')

        # Method panels start two cells after the original image
        for col, method_name in enumerate(method_names, start=2):
            prediction = segmented_masks[method_name][image_idx]
            plt.subplot(n_rows, n_cols, first_cell + col)
            plt.imshow(prediction, cmap='gray')

            # Score this single prediction for the panel title
            iou, dice = calculate_metrics(prediction, truth)
            plt.title(f"{method_name}\nIoU: {iou:.2f}, Dice: {dice:.2f}")
            plt.axis('off')

    plt.tight_layout()
    plt.show()

    # Overall comparison: one bar chart per metric, both on a 0-1 scale
    plt.figure(figsize=(12, 6))
    for position, score_name in enumerate(('IoU', 'Dice'), start=1):
        plt.subplot(1, 2, position)
        plt.bar(metrics.keys(), [metrics[m][score_name] for m in metrics.keys()])
        plt.title(f'Average {score_name} Score')
        plt.ylim(0, 1)

    plt.tight_layout()
    plt.show()

# Run every traditional segmenter over the loaded images
segmented_masks = apply_segmentation_methods(images=images)

# Score each method against the ground truth
metrics = evaluate_segmentation_methods(
    segmented_masks=segmented_masks,
    ground_truth_masks=ground_truth_masks,
)

# Visualize the first three images together with the averaged scores
display_segmentation_results(
    images=images,
    ground_truth_masks=ground_truth_masks,
    segmented_masks=segmented_masks,
    indices=[0, 1, 2],
    metrics=metrics,
)

# Task 4: Mask Segmentation Using U-Net

# Define U-Net model
class DoubleConv(nn.Module):
    """Two consecutive 3x3 Conv -> BatchNorm -> ReLU stages used as the U-Net building block.

    ``padding=1`` keeps the spatial size unchanged; only the channel count
    changes, from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # The first stage maps in_channels -> out_channels, the second keeps
        # out_channels. Layers are created in forward order so the Sequential
        # indices match the order they are applied in.
        layers = []
        for stage_in in (in_channels, out_channels):
            layers.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both convolution stages to ``x``."""
        return self.double_conv(x)

class UNet(nn.Module):
    def __init__(self, in_channels: int = 3, out_channels: int = 1) -> None:
        """Create the encoder, bottleneck, decoder and output layers.

        Args:
            in_channels: Number of channels of the input fed to the first
                encoder block.
            out_channels: Number of channels produced by the final 1x1
                convolution.
        """
        super(UNet, self).__init__()

        # Encoder (downsampling): four DoubleConv blocks with 64, 128, 256 and
        # 512 output channels, each paired with a MaxPool2d(2) layer.
        self.enc1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.enc4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)

        # Bottleneck: widest block, 512 -> 1024 channels
        self.bottleneck = DoubleConv(512, 1024)

        # Decoder (upsampling): each stage is a stride-2 transposed convolution
        # that halves the channel count, followed by a DoubleConv whose input
        # width is doubled because forward() concatenates the matching encoder
        # output along the channel dimension.
        self.up_conv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = DoubleConv(1024, 512)  # 512 + 512 = 1024 input channels after concat

        self.up_conv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(512, 256)  # 256 + 256 = 512 input channels after concat

        self.up_conv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(256, 128)  # 128 + 128 = 256 input channels after concat

        self.up_conv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(128, 64)  # 64 + 64 = 128 input channels after concat

        # Final output layer: 1x1 convolution mapping 64 channels to out_channels
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)
    
    def forward(self, x):
        # Encoder
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool1(enc1))
        enc3 = self.enc3(self.pool2(enc2))
        enc4 = self.enc4(self.pool3(enc3))
        
        # Bottleneck
        bottleneck = self.bottleneck(self.pool4(enc4))
        
        # Decoder with skip connections
        dec4 = self.up_conv4(bottleneck)
        dec4 = torch.cat([dec4, enc4], dim=1)
        dec4 = self.dec4(dec4)
        
        dec3 = self.up_conv3(dec4)
        dec3 = torch.cat([dec3, enc3], dim=1)
        dec3 = self.dec3(dec3)
        
        dec2 = self.up_conv2(dec3)
        dec2 = torch.cat([dec2, enc2], dim=1)
        dec2 = self.dec2(dec2