# Semantic Segmentation with U-Net

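This notebook implements a U-Net semantic segmentation model in PyTorch. U-Net is an encoder-decoder architecture with skip connections that is particularly effective for semantic segmentation tasks. The notebook defines the model, a dataset class for image/mask pairs, and training and visualization utilities; the example training run at the end is left commented out because it needs your own images and masks.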

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os

# For downloading sample data
import urllib.request
import zipfile

# Seed PyTorch's and NumPy's global random number generators with a fixed
# value so that repeated runs of the notebook draw the same random numbers.
torch.manual_seed(42)
np.random.seed(42)

In [None]:
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by batch norm and ReLU.

    Both convolutions use ``padding=1``, so the spatial size of the input is
    preserved; only the channel count changes
    (``in_channels`` -> ``out_channels`` -> ``out_channels``).

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # The first stage maps in_channels -> out_channels; the second stage
        # keeps the channel count at out_channels.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        # Attribute name and layer order are kept so that parameter names in
        # the state_dict ("double_conv.0.weight", ...) stay the same.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv-BN-ReLU stages and return the result."""
        out = self.double_conv(x)
        return out

class UNet(nn.Module):
    """U-Net encoder-decoder that maps an image batch to per-pixel class logits.

    The encoder applies a ``DoubleConv`` followed by four
    ``MaxPool2d(2)`` + ``DoubleConv`` stages (64 -> 128 -> 256 -> 512 -> 1024
    channels). The decoder mirrors it with four ``ConvTranspose2d``
    upsamplings; after each one the result is concatenated with the encoder
    feature map of the same level and passed through a ``DoubleConv``. A final
    1x1 convolution produces ``n_classes`` channels. No softmax is applied, so
    the output holds raw logits.

    Args:
        n_channels: Number of channels in the input images.
        n_classes: Number of output channels (one logit map per class).
    """

    def __init__(self, n_channels=3, n_classes=2):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        # Encoder: each "down" stage halves the resolution and doubles channels.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(64, 128)
        )
        self.down2 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(128, 256)
        )
        self.down3 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(256, 512)
        )
        self.down4 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(512, 1024)
        )

        # Decoder: each "up" halves the channels and doubles the resolution;
        # the following DoubleConv consumes [skip, upsampled], which together
        # have twice the upsampled channel count.
        self.up1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.conv1 = DoubleConv(1024, 512)
        self.up2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv2 = DoubleConv(512, 256)
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv3 = DoubleConv(256, 128)
        self.up4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv4 = DoubleConv(128, 64)

        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    @staticmethod
    def _pad_to_match(x, skip):
        """Pad ``x`` so its height and width equal those of ``skip``.

        ``MaxPool2d(2)`` floors odd sizes, so after the stride-2 transposed
        convolution the upsampled map can be one pixel smaller than the
        encoder map it is concatenated with. The difference is split between
        both sides. When the sizes already agree ``x`` is returned untouched.
        """
        diff_h = skip.size(2) - x.size(2)
        diff_w = skip.size(3) - x.size(3)
        if diff_h == 0 and diff_w == 0:
            return x
        # F.pad takes (left, right, top, bottom) for the last two dimensions.
        return F.pad(
            x,
            [diff_w // 2, diff_w - diff_w // 2,
             diff_h // 2, diff_h - diff_h // 2],
        )

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Tensor of shape ``(N, n_channels, H, W)``. ``H`` and ``W`` do
                not have to be multiples of 16; the four pooling stages still
                need each of them to be at least 16.

        Returns:
            Tensor of shape ``(N, n_classes, H, W)`` with unnormalized scores.
        """
        # Encoder
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        # Decoder: upsample, align with the matching encoder map, then fuse.
        decoder_stages = (
            (self.up1, self.conv1, x4),
            (self.up2, self.conv2, x3),
            (self.up3, self.conv3, x2),
            (self.up4, self.conv4, x1),
        )
        x = x5
        for upsample, fuse, skip in decoder_stages:
            x = self._pad_to_match(upsample(x), skip)
            x = fuse(torch.cat([skip, x], dim=1))

        return self.outc(x)

In [None]:
class SegmentationDataset(Dataset):
    """Dataset of (image, mask) pairs stored in two directories.

    Every regular, non-hidden file in ``image_dir`` is treated as an image.
    Its mask is looked up in ``mask_dir``: for a name ending in ``.jpg`` the
    extension is replaced by ``_mask.png``; any other name is used unchanged.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the masks.
        transform: Optional callable applied to the RGB image. When
            ``mask_transform`` is not given it is applied to the mask as well.
        mask_transform: Optional callable applied to the grayscale mask
            instead of ``transform``. Use it when masks need different
            handling from images, for example nearest-neighbour resizing and
            conversion to integer class indices.

    NOTE(review): with the default fallback the image and mask are transformed
    by two independent calls, so random augmentations are not synchronized
    between them.
    """

    def __init__(self, image_dir, mask_dir, transform=None, *, mask_transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_transform = mask_transform
        # os.listdir returns entries in arbitrary order and includes
        # directories and hidden entries (e.g. ".ipynb_checkpoints"); sort for
        # a reproducible index -> file mapping and keep only visible files.
        self.images = sorted(
            name for name in os.listdir(image_dir)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(image_dir, name))
        )

    def __len__(self):
        """Return the number of images found in ``image_dir``."""
        return len(self.images)

    @staticmethod
    def _mask_name(image_name):
        """Return the mask file name that belongs to ``image_name``."""
        suffix = '.jpg'
        # Only swap a trailing ".jpg"; str.replace would also rewrite the same
        # substring if it appeared in the middle of the file name.
        if image_name.endswith(suffix):
            return image_name[:-len(suffix)] + '_mask.png'
        return image_name

    def __getitem__(self, idx):
        """Load and return the ``(image, mask)`` pair at position ``idx``.

        The image is converted to RGB and the mask to single-channel
        grayscale before the transforms are applied.
        """
        name = self.images[idx]
        img_path = os.path.join(self.image_dir, name)
        mask_path = os.path.join(self.mask_dir, self._mask_name(name))

        # convert() returns a new, fully loaded image, so the underlying
        # files can be closed as soon as the with-blocks exit.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')
        with Image.open(mask_path) as raw_mask:
            mask = raw_mask.convert('L')

        if self.transform is not None:
            image = self.transform(image)

        # Fall back to the image transform to keep the previous behaviour for
        # callers that do not pass a dedicated mask transform.
        mask_transform = self.mask_transform
        if mask_transform is None:
            mask_transform = self.transform
        if mask_transform is not None:
            mask = mask_transform(mask)

        return image, mask

In [None]:
def train_model(model, train_loader, criterion, optimizer, device, num_epochs=10):
    """Train ``model`` for ``num_epochs`` passes over ``train_loader``.

    Each batch is moved to ``device``, run through the model, scored with
    ``criterion`` and used for one optimizer step. The loss of every tenth
    batch and the average loss of every epoch are printed.

    Args:
        model: Module to train; it is switched to training mode.
        train_loader: Sized iterable yielding ``(images, masks)`` batches.
        criterion: Callable ``criterion(outputs, masks)`` returning a scalar
            loss tensor. The required mask shape and dtype depend on the
            criterion and are not checked here.
        optimizer: Optimizer that updates the parameters of ``model``.
        device: Device the batches are moved to. The model itself is not
            moved by this function.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        A list with the average loss of each epoch. An epoch over an empty
        loader is recorded as NaN instead of raising ``ZeroDivisionError``.
    """
    model.train()
    num_batches = len(train_loader)
    epoch_losses = []

    for epoch in range(num_epochs):
        running_loss = 0.0
        for batch_idx, (images, masks) in enumerate(train_loader):
            images = images.to(device)
            masks = masks.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, masks)

            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Read the scalar once; every .item() call forces a device sync.
            batch_loss = loss.item()
            running_loss += batch_loss

            if batch_idx % 10 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{num_batches}], '
                      f'Loss: {batch_loss:.4f}')

        # No batches means there is nothing to average; report NaN rather
        # than dividing by zero.
        epoch_loss = running_loss / num_batches if num_batches else float('nan')
        epoch_losses.append(epoch_loss)
        print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {epoch_loss:.4f}')

    return epoch_losses

def visualize_prediction(model, image, device):
    """Show an image next to the class mask the model predicts for it.

    Args:
        model: Segmentation model returning ``(N, n_classes, H, W)`` scores.
            It is switched to evaluation mode and left in that mode.
        image: Single image tensor of shape ``(C, H, W)``; a batch axis is
            added before it is passed to the model.
        device: Device on which the forward pass is run.

    The predicted mask is the per-pixel argmax over the class axis. The
    function displays a matplotlib figure and returns nothing.
    """
    model.eval()
    with torch.no_grad():
        batch = image.unsqueeze(0).to(device)
        output = model(batch)
        pred_mask = torch.argmax(output, dim=1)

    # Drop only the batch axis. A bare squeeze() would also remove a channel,
    # height or width axis of size 1 and break the permute below.
    display_image = batch.squeeze(0).cpu().permute(1, 2, 0).numpy()
    if display_image.shape[2] == 1:
        # imshow expects (H, W) for single-channel data, not (H, W, 1).
        display_image = display_image[:, :, 0]
    pred_mask = pred_mask.squeeze(0).cpu().numpy()

    # Plot
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
    # The colormap only matters for 2-D input; it is ignored for RGB data.
    ax1.imshow(display_image, cmap='gray' if display_image.ndim == 2 else None)
    ax1.set_title('Original Image')
    ax1.axis('off')

    ax2.imshow(pred_mask, cmap='gray')
    ax2.set_title('Predicted Mask')
    ax2.axis('off')

    plt.show()

In [None]:
# Run on the GPU when CUDA is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Build a 3-channel-in, 2-class-out U-Net and place it on the chosen device.
model = UNet(n_channels=3, n_classes=2)
model = model.to(device)

# Per-pixel cross-entropy loss, optimized with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Resize every input to 256x256, then convert it to a tensor.
# NOTE(review): the example below passes this same transform to the dataset,
# which also applies it to the masks. ToTensor presumably yields float masks
# of shape (1, H, W), while CrossEntropyLoss with class-index targets expects
# integer (N, H, W) tensors - confirm the mask format before training.
transform = transforms.Compose(
    [transforms.Resize((256, 256)), transforms.ToTensor()]
)

# Example of how to use the model (kept inside a string because it needs real data)
'''
# Create dataset and dataloader
dataset = SegmentationDataset(
    image_dir='path/to/images',
    mask_dir='path/to/masks',
    transform=transform
)
data_loader = DataLoader(dataset, batch_size=4, shuffle=True)

# Train the model
train_model(model, data_loader, criterion, optimizer, device)

# Visualize a prediction
sample_image, _ = dataset[0]
visualize_prediction(model, sample_image, device)
'''

# USE COLAB

## Usage Instructions

1. This implementation provides a complete U-Net model for semantic segmentation.
2. The number of input channels and output classes is configurable through the `n_channels` and `n_classes` arguments of `UNet`.
3. To use this model with your own data:
   - Organize your images and corresponding masks in separate directories; the mask for an image named `photo.jpg` must be named `photo_mask.png`
   - Create a SegmentationDataset instance with your directories
   - Use the provided training utilities to train the model
   - Use the visualization function to see the results

### Model Architecture

The implemented U-Net architecture consists of:
- Encoder path with 4 downsampling blocks
- Decoder path with 4 upsampling blocks
- Skip connections between encoder and decoder
- Double convolution blocks with batch normalization

### Requirements

```bash
pip install torch torchvision pillow matplotlib numpy
```