# Semantic Segmentation with DeepLabV3+ on Grayscale Images

This notebook demonstrates how to train a semantic segmentation model using DeepLabV3+ on grayscale images. We'll use the `segmentation_models_pytorch` library, which provides pre-implemented models and utilities for semantic segmentation tasks.

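## 1. Prerequisites

The cells below require PyTorch, torchvision, `segmentation_models_pytorch`, NumPy, Pillow, and Matplotlib.

## 2. Imports
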
In [1]:
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

import segmentation_models_pytorch as smp

## 3. Dataset Preparation

We'll create a custom dataset class to handle grayscale images and their corresponding masks.

In [None]:
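class GrayscaleSegmentationDataset(Dataset):
    """Pairs each grayscale image with the mask that has the same filename."""

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # Sort for a deterministic order; os.listdir order is arbitrary
        self.images = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.image_dir, img_name)
        mask_path = os.path.join(self.mask_dir, img_name)

        image = Image.open(img_path).convert('L')
        mask = Image.open(mask_path).convert('L')

        # Image: apply the transform, or fall back to a plain tensor conversion
        if self.transform:
            image = self.transform(image)
        else:
            image = transforms.ToTensor()(image)

        # Mask: assumed to store class indices (0..num_classes-1) as pixel values;
        # if binary masks are stored as 0/255, remap 255 to 1 first.
        # ToTensor would rescale labels to [0, 1] and .long() would then truncate
        # most of them to 0, so convert via NumPy instead. Resize with
        # nearest-neighbor so interpolation cannot invent new labels.
        height, width = image.shape[-2:]
        mask = mask.resize((width, height), Image.NEAREST)
        mask = torch.from_numpy(np.array(mask, dtype=np.int64)).unsqueeze(0)

        return image, mask  # shapes: (1, H, W) float, (1, H, W) int64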
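
Label masks need more care than images when they are resized. As a small illustration, the sketch below uses a made-up 4x4 mask with class indices 0 and 3 (not data from this notebook) and upsamples it with two Pillow interpolation modes. Nearest-neighbor only copies existing labels, whereas bilinear interpolation can blend neighbouring labels into values that belong to no class.

```python
import numpy as np
from PIL import Image

# Hypothetical 4x4 mask: left half is class 0, right half is class 3.
labels = np.array([[0, 0, 3, 3]] * 4, dtype=np.uint8)
mask = Image.fromarray(labels)

# Upsample to 8x8 with two different interpolation modes.
nearest = np.array(mask.resize((8, 8), Image.NEAREST))
bilinear = np.array(mask.resize((8, 8), Image.BILINEAR))

print(np.unique(nearest))   # only labels that were already in the mask
print(np.unique(bilinear))  # can contain blended values along the class boundary
```

`transforms.Resize` uses bilinear interpolation by default, so a mask passed through it is affected in the same way.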

## 4. DataLoader Setup

Define the transformations and create DataLoader instances for training and validation.

In [None]:
# Define transformations
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# Directories
train_image_dir = 'path/to/train/images'
train_mask_dir = 'path/to/train/masks'
val_image_dir = 'path/to/val/images'
val_mask_dir = 'path/to/val/masks'

# Datasets
train_dataset = GrayscaleSegmentationDataset(train_image_dir, train_mask_dir, transform)
val_dataset = GrayscaleSegmentationDataset(val_image_dir, val_mask_dir, transform)

# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

## 5. Model Initialization

Initialize the DeepLabV3+ model with `in_channels=1` for grayscale images.

In [None]:
# Specify the number of classes (including background)
num_classes = 2  # Change this based on your dataset

# Initialize model
model = smp.DeepLabV3Plus(
    encoder_name='resnet34',        # Choose encoder, e.g., resnet34
    encoder_weights=None,           # Use None to train from scratch
    in_channels=1,                  # Grayscale images have 1 channel
    classes=num_classes,            # Number of segmentation classes
)

## 6. Training Setup

Define the loss function, optimizer, and learning rate scheduler.

In [None]:
# Loss function
criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# Move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
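
`nn.CrossEntropyLoss` is particular about tensor shapes for segmentation: logits of shape `(B, C, H, W)` and integer targets of shape `(B, H, W)`. The sketch below uses random tensors (sizes chosen arbitrarily) in place of real model output and masks to show the shapes involved.

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

batch_size, num_classes, height, width = 4, 2, 8, 8

# Stand-in for model output: raw, unnormalized scores per class
logits = torch.randn(batch_size, num_classes, height, width)

# Stand-in for dataset masks: class indices with a channel dimension of 1
masks = torch.randint(0, num_classes, (batch_size, 1, height, width))

# squeeze(1) drops the channel dimension so the target is (B, H, W)
target = masks.squeeze(1)
loss = criterion(logits, target)

print(target.shape, target.dtype)  # torch.Size([4, 8, 8]) torch.int64
print(loss.item())                 # a single scalar averaged over all pixels
```

No softmax is applied to the logits beforehand, because `nn.CrossEntropyLoss` applies log-softmax internally.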

## 7. Training Loop

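Train for `num_epochs` epochs. Each epoch runs one pass over the training set, steps the learning rate scheduler once, and prints the mean batch loss.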

In [None]:
num_epochs = 25

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0

    for images, masks in train_loader:
        images = images.to(device)
        masks = masks.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, masks.squeeze(1))

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # Adjust learning rate
    scheduler.step()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss/len(train_loader):.4f}')
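
With `step_size=10` and `gamma=0.1`, `StepLR` divides the learning rate by 10 after every 10 calls to `scheduler.step()`. The sketch below traces that schedule over 25 epochs using a single dummy parameter in place of the model (an assumption made to keep the example standalone).

```python
import torch

# Dummy parameter standing in for model.parameters()
param = torch.nn.Parameter(torch.zeros(1))
param.grad = torch.zeros_like(param)  # give the optimizer a gradient to consume

optimizer = torch.optim.Adam([param], lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

lrs = []
for epoch in range(25):
    lrs.append(optimizer.param_groups[0]['lr'])  # rate used during this epoch
    optimizer.step()    # stands in for the per-batch updates
    scheduler.step()    # called once per epoch, after the optimizer

for epoch in (0, 9, 10, 19, 20, 24):
    print(f'epoch {epoch + 1:2d}: lr = {lrs[epoch]:.0e}')
```

Epochs 1 to 10 run at about 1e-3, epochs 11 to 20 at about 1e-4, and epochs 21 to 25 at about 1e-5.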

## 8. Validation

Evaluate the model on the validation set by computing pixel accuracy: the fraction of pixels whose predicted class matches the mask.

In [None]:
model.eval()
with torch.no_grad():
    total = 0
    correct = 0
    for images, masks in val_loader:
        images = images.to(device)
        masks = masks.to(device)

        outputs = model(images)
        predicted = outputs.argmax(dim=1)  # class with the highest logit per pixel
        total += masks.nelement()
        correct += (predicted == masks.squeeze(1)).sum().item()

    print(f'Validation Pixel Accuracy: {100 * correct / total:.2f}%')
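
Pixel accuracy can look high when one class (often background) dominates, so a per-class intersection over union (IoU) is a common complement. The helper below is a hypothetical sketch, not part of `segmentation_models_pytorch`; it expects integer class maps of identical shape, such as `predicted` and `masks.squeeze(1)` from the loop above.

```python
import torch

def per_class_iou(predicted, target, num_classes):
    """Return a list with the IoU of each class (NaN if the class is absent)."""
    ious = []
    for cls in range(num_classes):
        pred_mask = predicted == cls
        target_mask = target == cls
        union = (pred_mask | target_mask).sum().item()
        if union == 0:
            ious.append(float('nan'))  # class appears in neither tensor
        else:
            intersection = (pred_mask & target_mask).sum().item()
            ious.append(intersection / union)
    return ious

# Tiny made-up example: 2x4 prediction and ground truth
predicted = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
target = torch.tensor([[0, 0, 0, 1], [0, 1, 1, 1]])

ious = per_class_iou(predicted, target, num_classes=2)
print(ious)  # [0.75, 0.8]
```

In this example 7 of 8 pixels are correct (87.5% pixel accuracy), while the per-class IoU values are 0.75 and 0.8.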

## 9. Saving the Model

Save the trained model for future use.

In [None]:
model_path = 'deeplabv3plus_grayscale.pth'
torch.save(model.state_dict(), model_path)
print(f'Model saved to {model_path}')
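
Because only the `state_dict` is saved, reloading requires rebuilding the same architecture first and then loading the weights into it. The sketch below shows that round trip with a small `nn.Conv2d` as a stand-in for the `smp.DeepLabV3Plus(...)` call (an assumption to keep the example self-contained) and a temporary file in place of `deeplabv3plus_grayscale.pth`.

```python
import os
import tempfile

import torch
import torch.nn as nn

def build_model():
    # Stand-in for the smp.DeepLabV3Plus(...) constructor used in this notebook;
    # the real call must use the same arguments as at training time.
    return nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3, padding=1)

trained = build_model()

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'weights.pth')
    torch.save(trained.state_dict(), path)

    # Rebuild the architecture, then load the saved weights into it
    restored = build_model()
    state_dict = torch.load(path, map_location='cpu')
    restored.load_state_dict(state_dict)

restored.eval()  # switch to inference behaviour before predicting

x = torch.randn(1, 1, 16, 16)
with torch.no_grad():
    same = torch.allclose(trained(x), restored(x))
print(same)  # True: both models produce identical outputs
```

`map_location='cpu'` lets weights saved on a GPU load on a machine without one; move the model to the target device afterwards.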

## 10. Conclusion

You've successfully trained a DeepLabV3+ semantic segmentation model on grayscale images. You can now use this model for inference or further fine-tuning.

**References:**

- [segmentation_models_pytorch Documentation](https://smp.readthedocs.io/en/latest/models.html#deeplabv3plus)