<a href="https://colab.research.google.com/github/MehrdadDastouri/unet_image_segmentation/blob/main/unet_image_segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os

# Select the compute device: use the GPU when CUDA is available, else the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# U-Net model definition
class UNet(nn.Module):
    """U-Net encoder/decoder that maps an image to a same-sized logit map.

    The network has four encoder stages (each a double 3x3 convolution
    followed by 2x2 max pooling), a bottleneck, and four decoder stages
    (2x transposed convolution, concatenation with the matching encoder
    output, double 3x3 convolution). A final 1x1 convolution produces
    ``out_channels`` raw scores per pixel; no activation is applied, so
    apply ``torch.sigmoid`` (or a logits-based loss) downstream.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the output map.
        base_channels: Width of the first encoder stage. Every deeper stage
            doubles it (default 64 -> 128 -> 256 -> 512 -> 1024).

    Input height/width need not be multiples of 16: when pooling rounds an
    odd size down, the upsampled decoder map is zero-padded back to the
    size of its skip connection. Each side must still be at least 16 pixels
    so that it survives the four 2x poolings.
    """

    def __init__(self, in_channels=1, out_channels=1, base_channels=64):
        super(UNet, self).__init__()
        width = base_channels

        # Encoder
        self.conv1 = self.double_conv(in_channels, width)
        self.conv2 = self.double_conv(width, width * 2)
        self.conv3 = self.double_conv(width * 2, width * 4)
        self.conv4 = self.double_conv(width * 4, width * 8)

        # Bottleneck
        self.bottleneck = self.double_conv(width * 8, width * 16)

        # Decoder (each dec* block sees upsampled + skip channels concatenated)
        self.upconv4 = self.upconv(width * 16, width * 8)
        self.dec4 = self.double_conv(width * 16, width * 8)
        self.upconv3 = self.upconv(width * 8, width * 4)
        self.dec3 = self.double_conv(width * 8, width * 4)
        self.upconv2 = self.upconv(width * 4, width * 2)
        self.dec2 = self.double_conv(width * 4, width * 2)
        self.upconv1 = self.upconv(width * 2, width)
        self.dec1 = self.double_conv(width * 2, width)

        # Final layer
        self.final = nn.Conv2d(width, out_channels, kernel_size=1)

        # Parameter-free pooling shared by all encoder stages; built once
        # here instead of on every forward pass. It adds no state_dict keys.
        self.pool = nn.MaxPool2d(2)

    def double_conv(self, in_channels, out_channels):
        """Return two 3x3 same-padding convolutions, each followed by ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def upconv(self, in_channels, out_channels):
        """Return a transposed convolution that doubles height and width."""
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    @staticmethod
    def _pad_to_match(upsampled, skip):
        """Zero-pad ``upsampled`` so its height/width equal those of ``skip``.

        Pooling floors odd sizes, so after upsampling a map can be one pixel
        short of the encoder feature it is concatenated with.
        """
        diff_h = skip.size(2) - upsampled.size(2)
        diff_w = skip.size(3) - upsampled.size(3)
        if diff_h == 0 and diff_w == 0:
            return upsampled
        # F.pad order for the last two dims is (left, right, top, bottom).
        return nn.functional.pad(
            upsampled,
            [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
        )

    def forward(self, x):
        """Compute per-pixel logits for a batch ``x`` of shape (N, C, H, W)."""
        # Encoder
        c1 = self.conv1(x)
        c2 = self.conv2(self.pool(c1))
        c3 = self.conv3(self.pool(c2))
        c4 = self.conv4(self.pool(c3))

        # Bottleneck
        bn = self.bottleneck(self.pool(c4))

        # Decoder
        u4 = self._pad_to_match(self.upconv4(bn), c4)
        d4 = self.dec4(torch.cat([u4, c4], dim=1))
        u3 = self._pad_to_match(self.upconv3(d4), c3)
        d3 = self.dec3(torch.cat([u3, c3], dim=1))
        u2 = self._pad_to_match(self.upconv2(d3), c2)
        d2 = self.dec2(torch.cat([u2, c2], dim=1))
        u1 = self._pad_to_match(self.upconv1(d2), c1)
        d1 = self.dec1(torch.cat([u1, c1], dim=1))

        return self.final(d1)

# Dataset class for image segmentation
class SegmentationDataset(Dataset):
    """Dataset of grayscale image/mask pairs stored in two directories.

    An image and its mask are paired by file name: the mask for
    ``image_dir/<name>`` is read from ``mask_dir/<name>``.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing a mask with the same file name for
            every image.
        transform: Optional callable invoked as
            ``transform(image=..., mask=...)`` with two NumPy arrays; it must
            return a mapping with ``"image"`` and ``"mask"`` entries.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # os.listdir order is arbitrary and includes sub-directories and
        # hidden entries (e.g. ".ipynb_checkpoints"); keep regular, visible
        # files only and sort them so indices are reproducible.
        self.images = sorted(
            name
            for name in os.listdir(image_dir)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(image_dir, name))
        )

    def __len__(self):
        """Return the number of image files found in ``image_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load the ``idx``-th image/mask pair.

        Returns:
            ``(image, mask)``: two 2-D uint8 arrays when no transform is set,
            otherwise the ``"image"`` and ``"mask"`` entries returned by the
            transform.
        """
        file_name = self.images[idx]
        img_path = os.path.join(self.image_dir, file_name)
        mask_path = os.path.join(self.mask_dir, file_name)

        # Context managers close the underlying files right after decoding
        # instead of leaving them open until garbage collection.
        with Image.open(img_path) as img_file:
            image = np.array(img_file.convert("L"))  # Convert to grayscale
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file.convert("L"))  # Convert to grayscale

        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]

        return image, mask

# Data augmentation and preprocessing
from torchvision.transforms import functional as TF

class Transform:
    """Resize an image/mask pair and convert both to float tensors.

    Args:
        size: Target ``(height, width)`` passed to ``TF.resize``.

    The callable accepts either NumPy arrays (as produced by
    ``SegmentationDataset``) or PIL images, and returns a dict with
    ``"image"`` and ``"mask"`` tensors produced by ``TF.to_tensor``.
    """

    def __init__(self, size=(128, 128)):
        self.size = size

    @staticmethod
    def _as_pil(data):
        """Convert a NumPy array to a PIL image; pass other inputs through.

        ``TF.resize`` only accepts PIL images or tensors, so raw arrays must
        be wrapped first.
        """
        if isinstance(data, np.ndarray):
            return Image.fromarray(data)
        return data

    def __call__(self, image, mask):
        image = self._as_pil(image)
        mask = self._as_pil(mask)

        # Resize. Masks use nearest-neighbour so label values are not blended
        # into intermediate grey levels at object borders.
        image = TF.resize(image, self.size)
        mask = TF.resize(
            mask, self.size, interpolation=TF.InterpolationMode.NEAREST
        )

        # Convert to tensor
        image = TF.to_tensor(image)
        mask = TF.to_tensor(mask)

        return {"image": image, "mask": mask}

# Build the training data pipeline: paired images/masks, batched and shuffled
train_transform = Transform()
train_dataset = SegmentationDataset(
    image_dir="data/images",
    mask_dir="data/masks",
    transform=train_transform,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True)

# Single-channel in, single-channel out U-Net placed on the selected device
model = UNet(in_channels=1, out_channels=1)
model = model.to(device)

# Binary cross-entropy on raw logits, optimised with Adam
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Train for a fixed number of passes over the training data
epochs = 10
for epoch in range(epochs):
    model.train()
    batch_losses = []
    for images, masks in train_loader:
        images = images.to(device)
        masks = masks.to(device)

        # Drop gradients left over from the previous step
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, masks)

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    # Report the mean batch loss for this epoch
    epoch_loss = sum(batch_losses)
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss / len(train_loader):.4f}")

# Predict a binary mask for the first training sample
model.eval()
sample_image, sample_mask = train_dataset[0]
sample_image = sample_image.unsqueeze(0).to(device)  # add the batch dimension
with torch.no_grad():
    probabilities = torch.sigmoid(model(sample_image))
# Threshold the probabilities at 0.5 to obtain a hard 0/1 mask
predicted_mask = (probabilities > 0.5).float()

# Show the input, its ground-truth mask, and the prediction side by side
panels = [
    ("Original Image", sample_image),
    ("Ground Truth Mask", sample_mask),
    ("Predicted Mask", predicted_mask),
]
plt.figure(figsize=(12, 4))
for position, (panel_title, panel_tensor) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.title(panel_title)
    plt.imshow(panel_tensor.cpu().squeeze(), cmap="gray")
plt.show()