# Binary Mask Generation using Mask R-CNN

This notebook demonstrates how to train a Mask R-CNN model to generate binary masks for images. We'll use a dataset with 50 images and 50 corresponding binary masks.

The notebook includes:
1. Data preparation and loading
2. Model definition and customization
3. Training the model on our dataset
4. Validation and performance metrics
5. Testing on new images and visualizing results

## 1. Import Dependencies

In [None]:
# Standard libraries
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
from tqdm.notebook import tqdm

# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

# Seed every random number generator used in this notebook so that runs are repeatable
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when PyTorch reports one as available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Create Dataset and DataLoader Classes

We'll create a custom dataset class to load our images and masks.

In [None]:
class BinaryMaskDataset(Dataset):
    """Dataset that pairs each RGB image with a binary mask for Mask R-CNN training.

    Images and masks are matched purely by position after sorting the file
    names of both directories, so the two folders must sort into the same
    order.

    Every sample is a tuple ``(image, target)`` where ``target`` contains:
        boxes:  float32 tensor ``[N, 4]`` in ``(x1, y1, x2, y2)`` format
        labels: int64 tensor ``[N]``, all ones (single foreground class)
        masks:  uint8 tensor ``[N, H, W]`` holding 0/1 values
    """

    # An external contour is kept as an object only when both sides of its
    # bounding box are strictly larger than this many pixels.
    MIN_BOX_SIDE = 10

    def __init__(self, img_dir, mask_dir, transform=None):
        """
        Args:
            img_dir (string): Directory with all the images
            mask_dir (string): Directory with all the masks
            transform (callable, optional): Optional transform applied to the
                image only (the mask and boxes are left untouched)

        Raises:
            ValueError: If the two directories contain a different number of files
        """
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform

        # Sorted listings; index i of one list is paired with index i of the other
        self.img_names = self._list_files(img_dir)
        self.mask_names = self._list_files(mask_dir)

        # Raise instead of assert so the check survives `python -O`
        if len(self.img_names) != len(self.mask_names):
            raise ValueError("Number of images and masks should be the same")

    @staticmethod
    def _list_files(directory):
        """Return the sorted names of the regular files directly inside ``directory``."""
        return sorted(
            name for name in os.listdir(directory)
            if os.path.isfile(os.path.join(directory, name))
        )

    @staticmethod
    def _fallback_box(mask):
        """Return one ``[x1, y1, x2, y2]`` box enclosing every non-zero pixel of ``mask``.

        The max corner is exclusive (max index + 1), matching the ``x + w`` /
        ``y + h`` convention used for contour boxes, so a mask that is a single
        pixel wide or tall still yields a box with positive width and height.
        An all-zero mask yields the small dummy box ``[0, 0, 10, 10]``.
        """
        rows, cols = np.nonzero(mask)
        if rows.size == 0:
            # NOTE(review): this labels an empty region as foreground; consider
            # filtering empty masks out of the dataset instead.
            return [0, 0, 10, 10]
        return [
            int(cols.min()),
            int(rows.min()),
            int(cols.max()) + 1,
            int(rows.max()) + 1,
        ]

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        # Load image
        img_path = os.path.join(self.img_dir, self.img_names[idx])
        image = Image.open(img_path).convert("RGB")

        # Load mask as grayscale and threshold it to 0/1
        mask_path = os.path.join(self.mask_dir, self.mask_names[idx])
        mask = np.array(Image.open(mask_path).convert("L"))
        mask = (mask > 0).astype(np.uint8)

        # One candidate object per external contour. Indexing with [-2] picks
        # the contour list for both the 2-value and 3-value return signatures
        # used by different OpenCV releases.
        contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

        boxes = []
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            # Only include boxes with sufficient size
            if w > self.MIN_BOX_SIDE and h > self.MIN_BOX_SIDE:
                boxes.append([x, y, x + w, y + h])

        # No contour passed the size filter: treat the whole mask as one object
        if not boxes:
            boxes.append(self._fallback_box(mask))

        # Convert everything into a torch.Tensor
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.ones((len(boxes),), dtype=torch.int64)  # All objects are of the same class
        full_mask = torch.as_tensor(mask, dtype=torch.uint8)

        if len(boxes) > 1:
            # One mask per object: the full mask restricted to that object's box
            masks = torch.zeros((len(boxes), mask.shape[0], mask.shape[1]), dtype=torch.uint8)
            for i, box in enumerate(boxes):
                x1, y1, x2, y2 = box.int().tolist()
                masks[i, y1:y2, x1:x2] = full_mask[y1:y2, x1:x2]
        else:
            # Just one object, expand dimensions to make it [1, H, W]
            masks = full_mask.unsqueeze(0)

        # Create target dictionary
        target = {
            "boxes": boxes,
            "labels": labels,
            "masks": masks,
        }

        # Apply transformations (image only)
        if self.transform:
            image = self.transform(image)

        return image, target

## 3. Define Data Transformations

In [None]:
# Per-channel mean / standard deviation used to standardise the RGB input
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

# Image -> tensor, then channel-wise normalisation with the statistics above.
# NOTE(review): torchvision detection models also normalise their inputs
# internally (image_mean / image_std) - confirm the model is configured so
# that normalisation is not applied twice.
data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
])

## 4. Load and Split the Dataset

In [None]:
# Locations of the input images and of their masks
img_dir = "path_to_images_folder"  # Replace with your actual path
mask_dir = "path_to_masks_folder"  # Replace with your actual path

# Build the complete dataset
full_dataset = BinaryMaskDataset(img_dir=img_dir, mask_dir=mask_dir, transform=data_transform)

# 70% training / 15% validation / whatever remains for testing
dataset_size = len(full_dataset)
train_size = int(0.7 * dataset_size)
val_size = int(0.15 * dataset_size)
test_size = dataset_size - (train_size + val_size)

# A dedicated, fixed-seed generator makes the split identical on every run
split_lengths = [train_size, val_size, test_size]
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    full_dataset,
    split_lengths,
    generator=split_generator,
)

# Create data loaders
def collate_fn(batch):
    """Regroup a batch of per-sample tuples into one tuple per field.

    ``[(img0, tgt0), (img1, tgt1)]`` becomes ``((img0, img1), (tgt0, tgt1))``.
    Nothing is stacked, so samples may differ in size.
    """
    per_field = zip(*batch)
    return tuple(per_field)

# Only the training loader shuffles; validation and test keep a fixed order
train_dataloader = DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=collate_fn,
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=2,
    shuffle=False,
    collate_fn=collate_fn,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_fn,
)

# Summary of the split sizes
print(f"Dataset size: {dataset_size}")
print(f"Training set size: {train_size}")
print(f"Validation set size: {val_size}")
print(f"Test set size: {test_size}")

## 5. Define and Initialize the Mask R-CNN Model

In [None]:
def get_model_instance_segmentation(num_classes, inputs_normalized=True):
    """Build a pre-trained Mask R-CNN whose heads predict ``num_classes`` classes.

    Args:
        num_classes (int): Number of output classes, background included.
        inputs_normalized (bool): Set to True (default) when the images fed to
            the model have already been mean/std normalised by the data
            pipeline, as this notebook's transforms do. torchvision detection
            models normalise every input internally with ``image_mean`` /
            ``image_std``; leaving that on would normalise the images twice.
            With True the internal normalisation is made an identity. Pass
            False when feeding raw ``ToTensor`` images in the 0-1 range.

    Returns:
        The model with freshly initialised box and mask predictors.
    """
    model_kwargs = {}
    if inputs_normalized:
        # Identity mean/std: keep the model's internal resize/batching but
        # skip its second round of normalisation.
        model_kwargs["image_mean"] = [0.0, 0.0, 0.0]
        model_kwargs["image_std"] = [1.0, 1.0, 1.0]

    # Load pre-trained Mask R-CNN model
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True, **model_kwargs)

    # Replace the box classification/regression head with one sized for num_classes
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Replace the mask head likewise, keeping the input channel count of the original
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    model.roi_heads.mask_predictor = MaskRCNNPredictor(
        in_features_mask,
        hidden_layer,
        num_classes
    )

    return model

# Build the network with two output classes (background and foreground)
# and move it onto the selected device
NUM_CLASSES = 2
model = get_model_instance_segmentation(num_classes=NUM_CLASSES)
model.to(device)

## 6. Create the Training Function

In [None]:
def train_one_epoch(model, optimizer, data_loader, device):
    """Run one optimisation pass over ``data_loader`` and return the mean batch loss.

    Args:
        model: Detection model; called with ``(images, targets)`` in train mode.
        optimizer: Optimizer updating the model's parameters.
        data_loader: Iterable yielding ``(images, targets)`` batches.
        device: Device the images and target tensors are moved to.

    Returns:
        Sum of the per-batch total losses divided by the number of batches.
    """
    model.train()

    running_total = 0

    with tqdm(data_loader, desc="Training") as progress:
        for batch_images, batch_targets in progress:
            batch_images = [img.to(device) for img in batch_images]
            batch_targets = [
                {key: value.to(device) for key, value in tgt.items()}
                for tgt in batch_targets
            ]

            # The model returns a dict of loss terms; optimise their sum
            loss_terms = model(batch_images, batch_targets)
            total_loss = sum(loss_terms.values())

            # Standard step: clear old gradients, backpropagate, update
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            # Track the scalar loss and show it on the progress bar
            batch_loss = total_loss.item()
            running_total += batch_loss
            progress.set_postfix({"loss": batch_loss})

    return running_total / len(data_loader)

## 7. Create the Validation Function

In [None]:
def validate(model, data_loader, device):
    """Return the mean IoU of the model's top prediction over ``data_loader``.

    For every image the highest-scoring predicted mask (thresholded at 0.5) is
    compared with the first ground-truth mask. An image for which the model
    returns no detection contributes an IoU of 0 rather than being skipped, so
    missed images lower the score instead of silently inflating it.

    Args:
        model: Detection model; switched to eval mode by this call.
        data_loader: Iterable yielding ``(images, targets)`` batches.
        device: Device the images are moved to before inference.

    Returns:
        float: Mean IoU over all images, or 0.0 when the loader is empty.
    """
    model.eval()

    iou_scores = []

    with torch.no_grad():
        with tqdm(data_loader, desc="Validation") as pbar:
            for images, targets in pbar:
                images = [image.to(device) for image in images]

                # Get model predictions (targets are only needed on the CPU below)
                outputs = model(images)

                for output, target in zip(outputs, targets):
                    if len(output['scores']) == 0:
                        # Nothing detected: count as a complete miss
                        iou_scores.append(0.0)
                        continue

                    # Predictions are indexed from position 0 as the top-scoring one
                    pred_mask = output['masks'][0, 0].cpu().numpy() > 0.5

                    # NOTE(review): only the first ground-truth instance is used;
                    # for multi-object targets consider the union of all masks.
                    gt_mask = target['masks'][0].cpu().numpy()

                    # Calculate IoU
                    intersection = np.logical_and(pred_mask, gt_mask).sum()
                    union = np.logical_or(pred_mask, gt_mask).sum()
                    iou_scores.append(float(intersection / union) if union > 0 else 0.0)

    # Calculate mean IoU
    return float(np.mean(iou_scores)) if iou_scores else 0.0

## 8. Training Loop

In [None]:
# Set training parameters
num_epochs = 10
learning_rate = 0.005

# Initialize optimizer (only parameters that require gradients are optimised)
params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(params, lr=learning_rate, momentum=0.9, weight_decay=0.0005)

# Learning rate scheduler: multiply the learning rate by 0.1 every 3 epochs
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

# Training history
history = {'train_loss': [], 'val_iou': []}

# Best model tracking
best_iou = 0
best_model_weights = None

# Train the model
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    # Train for one epoch
    train_loss = train_one_epoch(model, optimizer, train_dataloader, device)
    history['train_loss'].append(train_loss)

    # Validate
    val_iou = validate(model, val_dataloader, device)
    history['val_iou'].append(val_iou)

    # Update learning rate
    lr_scheduler.step()

    # Print epoch results
    print(f"Train Loss: {train_loss:.4f}, Validation IoU: {val_iou:.4f}")

    # Save best model
    if val_iou > best_iou:
        best_iou = val_iou
        # state_dict() hands back references to the live parameter tensors, so
        # a shallow dict copy would keep changing as training continues.
        # Clone every tensor to freeze a real snapshot of this epoch.
        best_model_weights = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
        print(f"New best model with IoU: {best_iou:.4f}")

# Load best model weights (absent when no epoch ever beat the initial IoU of 0;
# in that case the final-epoch weights are kept)
if best_model_weights is not None:
    model.load_state_dict(best_model_weights)
print(f"\nTraining completed. Best validation IoU: {best_iou:.4f}")

# Save the model
torch.save(model.state_dict(), 'mask_rcnn_binary_mask_model.pth')

## 9. Plot Training History

In [None]:
# Two side-by-side panels: training loss on the left, validation IoU on the right
fig, (loss_ax, iou_ax) = plt.subplots(1, 2, figsize=(12, 5))

# Training loss per epoch
loss_ax.plot(history['train_loss'])
loss_ax.set_title('Training Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')

# Validation IoU per epoch
iou_ax.plot(history['val_iou'])
iou_ax.set_title('Validation IoU')
iou_ax.set_xlabel('Epoch')
iou_ax.set_ylabel('Mean IoU')

plt.tight_layout()
plt.show()

## 10. Test on the Held-Out Test Set and Visualize Results

In [None]:
# Function to visualize model predictions
def visualize_prediction(model, image, target, device):
    """Plot an image, its ground-truth mask and the model's predicted mask.

    Args:
        model: Detection model; switched to eval mode by this call.
        image: Normalised image tensor in channel-first layout.
        target: Target dict; only the first entry of ``target['masks']`` is used.
        device: Device on which inference is run.

    Returns:
        IoU between the predicted mask and the ground-truth mask (0 when both
        are empty or when no prediction scores above 0.5).
    """
    model.eval()

    # Inference on a single-image batch
    image = image.to(device)
    with torch.no_grad():
        prediction = model([image])[0]

    # Back to height x width x channel and undo the mean/std normalisation
    channel_std = np.array([0.229, 0.224, 0.225])
    channel_mean = np.array([0.485, 0.456, 0.406])
    image = image.cpu().numpy().transpose(1, 2, 0)
    image = np.clip(image * channel_std + channel_mean, 0, 1)

    gt_mask = target['masks'][0].cpu().numpy()

    # Accept the first prediction only when it exists and scores above 0.5
    has_confident_prediction = (
        len(prediction['masks']) > 0 and prediction['scores'][0] > 0.5
    )
    if has_confident_prediction:
        pred_mask = prediction['masks'][0, 0].cpu().numpy() > 0.5
    else:
        pred_mask = np.zeros_like(gt_mask)

    # Intersection over union of the two binary masks
    intersection = np.logical_and(pred_mask, gt_mask).sum()
    union = np.logical_or(pred_mask, gt_mask).sum()
    iou = intersection / union if union > 0 else 0

    # One row, three panels: image, ground truth, prediction
    panels = [
        (image, 'Original Image', None),
        (gt_mask, 'Ground Truth Mask', 'gray'),
        (pred_mask, f'Predicted Mask (IoU: {iou:.3f})', 'gray'),
    ]
    _, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (data, title, cmap) in zip(axes, panels):
        ax.imshow(data, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

    return iou

In [None]:
# Evaluate on samples from the held-out test split (not the validation split)
num_test_images = min(5, len(test_dataset))  # Test on up to 5 images
test_iou_scores = []

for i in range(num_test_images):
    # Each dataset item is an (image, target) pair
    image, target = test_dataset[i]
    # Shows the three-panel figure and returns the IoU for this sample
    iou = visualize_prediction(model, image, target, device)
    test_iou_scores.append(iou)
    print(f"Test image {i+1}, IoU: {iou:.4f}")

# Average over the visualised samples only
mean_test_iou = np.mean(test_iou_scores)
print(f"\nMean test IoU: {mean_test_iou:.4f}")

## 11. Additional Function: Baseline Using the Pre-trained Mask R-CNN

In [None]:
def apply_maskrcnn(image):
    """Segment ``image`` with an off-the-shelf Mask R-CNN and merge the results.

    A pre-trained model is loaded on every call. Every detection scoring above
    0.5 has its mask thresholded at 0.5, and all such masks are OR-ed together.

    Args:
        image: Image object exposing ``height`` and ``width`` attributes that
            ``transforms.ToTensor`` can convert.

    Returns:
        uint8 array of shape ``(height, width)`` with 255 on detected pixels
        and 0 elsewhere.
    """
    use_cuda = torch.cuda.is_available()

    # Pre-trained network in inference mode, on the GPU when one is available
    net = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    net.eval()
    if use_cuda:
        net.to('cuda')

    # Image -> tensor -> batch of one
    to_tensor = transforms.Compose([
        transforms.ToTensor()
    ])
    batch = to_tensor(image).unsqueeze(0)
    if use_cuda:
        batch = batch.to('cuda')

    with torch.no_grad():
        result = net(batch)[0]

    # Start from an empty canvas and OR in every confident instance mask
    merged = np.zeros((image.height, image.width), dtype=np.uint8)
    for score, instance in zip(result['scores'], result['masks']):
        if score > 0.5:  # confidence threshold
            instance_binary = (instance[0].cpu().numpy() > 0.5).astype(np.uint8)
            merged = np.logical_or(merged, instance_binary).astype(np.uint8)

    # Scale 0/1 to 0/255
    return merged * 255

In [None]:
# Compare pre-trained model with our fine-tuned model
def compare_models(image_path, device):
    """Show an image next to the masks from the pre-trained and fine-tuned models.

    The fine-tuned model is the module-level ``model``; it is switched to eval
    mode here because a detection model in train mode cannot be called without
    targets.

    Args:
        image_path: Path of the image file to load.
        device: Device on which the fine-tuned model runs.
    """
    # Load the image once; the same pixels are displayed and fed to both
    # models, so the picture and the masks cannot differ in orientation.
    pil_image = Image.open(image_path).convert("RGB")
    display_image = np.array(pil_image)

    # Get mask from pre-trained model
    pretrained_mask = apply_maskrcnn(pil_image)

    # Get mask from our fine-tuned model (same preprocessing as in training)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    input_tensor = transform(pil_image).to(device)

    model.eval()
    with torch.no_grad():
        prediction = model([input_tensor])[0]

    # Use the first prediction only when it exists and scores above 0.5
    if len(prediction['masks']) > 0 and prediction['scores'][0] > 0.5:
        top_mask = prediction['masks'][0, 0].cpu().numpy() > 0.5
        finetuned_mask = top_mask.astype(np.uint8) * 255
    else:
        finetuned_mask = np.zeros((pil_image.height, pil_image.width), dtype=np.uint8)

    # Display results
    plt.figure(figsize=(15, 5))

    plt.subplot(1, 3, 1)
    plt.imshow(display_image)
    plt.title('Original Image')
    plt.axis('off')

    plt.subplot(1, 3, 2)
    plt.imshow(pretrained_mask, cmap='gray')
    plt.title('Pre-trained Mask R-CNN')
    plt.axis('off')

    plt.subplot(1, 3, 3)
    plt.imshow(finetuned_mask, cmap='gray')
    plt.title('Fine-tuned Mask R-CNN')
    plt.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Test comparison on a sample image (replace with an actual test image path)
# sample_image_path = "path_to_test_image.jpg"  # Replace with your test image path
# compare_models(sample_image_path, device)

## 12. Conclusion and Next Steps

In this notebook, we have:
1. Created a custom dataset for binary mask segmentation
2. Loaded and prepared the data with proper transformations
3. Initialized a pre-trained Mask R-CNN model and replaced its box and mask heads for our two classes
4. Fine-tuned the model on our dataset
5. Evaluated the model using IoU metric
6. Visualized the results and compared with a pre-trained model

Potential next steps:
- Experiment with different backbones (ResNet-101, etc.)
- Try different data augmentation techniques
- Tune hyperparameters (learning rate, batch size, etc.)
- Apply post-processing to improve mask quality
- Implement ensemble methods for better predictions