# Dataset

**Dataset Structure**:

- Cloth-Segmentation-Dataset/
- ├── train_images/   ← Training input images
- ├── train_masks/    ← Ground truth masks for training
- ├── val_images/     ← Validation input images
- └── val_masks/      ← Ground truth masks for validation

Each image in the *_images folders has a corresponding mask in the matching *_masks folder with the same file name.
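
A missing mask only shows up as a file error when that sample is loaded, so it can help to check the pairing up front. The sketch below is one way to do that. `find_unpaired` is a hypothetical helper (not part of this notebook), and the demo runs on a throwaway temporary folder rather than the real dataset.

```python
import os
import tempfile

def find_unpaired(img_dir, mask_dir):
    """Return image file names that have no same-named file in mask_dir."""
    mask_names = set(os.listdir(mask_dir))
    return sorted(name for name in os.listdir(img_dir) if name not in mask_names)

# Demo on a temporary folder: b.png deliberately has no mask.
with tempfile.TemporaryDirectory() as root:
    img_dir = os.path.join(root, "train_images")
    mask_dir = os.path.join(root, "train_masks")
    os.makedirs(img_dir)
    os.makedirs(mask_dir)
    for name in ("a.png", "b.png"):
        open(os.path.join(img_dir, name), "w").close()
    open(os.path.join(mask_dir, "a.png"), "w").close()
    missing = find_unpaired(img_dir, mask_dir)

print(missing)  # ['b.png']
```

On the real dataset, calling the helper on each image/mask folder pair should return an empty list.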

In [None]:
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np

In [None]:
class ClothDataset(Dataset):
  def __init__(self, img_dir, mask_dir, transform=None):
    self.img_dir = img_dir
    self.mask_dir = mask_dir
    self.transform = transform

    self.images = sorted(os.listdir(img_dir))  # os.listdir order is arbitrary; sort for reproducible indexing

  def __len__(self):
    return len(self.images)

  def __getitem__(self, idx):
    img_path = os.path.join(self.img_dir, self.images[idx])
    mask_path = os.path.join(self.mask_dir, self.images[idx])
    image = np.array(Image.open(img_path).convert("RGB"))
    mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)
    mask[mask == 255.0] = 1.0

    if self.transform is not None:
      augmentations = self.transform(image=image, mask=mask)
      image = augmentations["image"]
      mask = augmentations["mask"]

    return image, mask

# Model

**UNET Architecture:**

U-Net is a symmetric encoder-decoder network:

- Down path (encoder): captures context via convolution + pooling.

- Up path (decoder): restores spatial resolution using transposed convolutions, with skip connections carrying fine detail across from the encoder.
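
To make the symmetry concrete, the sketch below lists the channel counts at each stage. It assumes four levels with 64, 128, 256 and 512 filters, 3 input channels and 1 output channel. These are illustrative values, and `unet_channel_plan` is a hypothetical helper used only for this walkthrough.

```python
def unet_channel_plan(in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
    """Return (stage, channels_in, channels_out) tuples for a U-Net layout."""
    plan = []
    channels = in_channels
    for f in features:
        plan.append(("down", channels, f))       # encoder block
        channels = f
    plan.append(("bottleneck", channels, channels * 2))
    for f in reversed(features):
        plan.append(("up", f * 2, f))            # transposed conv halves the channels
        plan.append(("merge", f * 2, f))         # skip (f) + upsampled (f) are concatenated, then reduced to f
    plan.append(("final", features[0], out_channels))
    return plan

for stage, c_in, c_out in unet_channel_plan():
    print(f"{stage:<10} {c_in:>4} -> {c_out}")
```

Each decoder level undoes one encoder level, which is why the channel counts mirror each other around the bottleneck.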

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

**nn.Conv2d(..., bias=False)**

bias=False means the convolution adds no bias term.

Here's why:

- A convolution bias is redundant when BatchNorm follows directly. BatchNorm subtracts the per-channel mean, which cancels any constant the convolution adds, and its own learnable shift (beta) takes over the role of the bias.

- Removing the bias saves a small amount of memory and computation.

- Accuracy is unaffected, because the bias would be cancelled by the normalization anyway.
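
The redundancy can be checked with plain arithmetic. The sketch below normalizes one channel's activations with and without a constant bias added. The numbers are made up, and `normalize` is a hypothetical stand-in for the normalization step of BatchNorm in training mode (no gamma or beta).

```python
import statistics

def normalize(values, eps=1e-5):
    """Normalize one channel's activations to zero mean and unit variance."""
    mean = statistics.fmean(values)
    var = statistics.pvariance(values, mu=mean)  # biased (population) variance
    return [(v - mean) / (var + eps) ** 0.5 for v in values]

activations = [0.5, -1.2, 3.0, 0.7]           # made-up conv outputs for one channel
bias = 2.5                                    # made-up constant conv bias
with_bias = [v + bias for v in activations]

plain = normalize(activations)
shifted = normalize(with_bias)

# The bias shifts the mean by the same amount, so it drops out of (v - mean).
max_diff = max(abs(a - b) for a, b in zip(plain, shifted))
print(max_diff)  # only floating-point rounding remains
```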

---

**nn.BatchNorm2d()**

Batch Normalization normalizes each feature (channel), using statistics computed over the batch and spatial dimensions, to have:

- Mean = 0

- Standard deviation = 1

This helps stabilize training, but too much normalization can limit the model's expressiveness. So we add back learnable parameters:

- γ (gamma): a scale factor

- β (beta): a shift (bias) factor

**Mathematically:**

Let’s say the output from a conv layer is x.

BatchNorm does:


$$
x_{\text{normalized}} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}
$$

$$
\text{output} = \gamma \cdot x_{\text{normalized}} + \beta
$$


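- μ and σ² are the mean and variance, computed per feature map (channel) over the batch and spatial dimensions; ε is a small constant that keeps the division numerically stable.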

- γ and β are learnable (optimized via backpropagation).

Why?

- Sometimes you don’t want purely normalized values.

- γ and β let the model recover the original distribution if that's what helps the task.

- This gives the network flexibility to decide how much normalization is useful for each feature.
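
As a worked example of the two formulas, the sketch below applies them to four made-up values from a single channel. `batch_norm` is a hypothetical plain-Python helper, not the PyTorch layer. Setting γ to the channel's standard deviation and β to its mean recovers the original values, which is the flexibility the points above describe.

```python
import statistics

def batch_norm(values, gamma=1.0, beta=0.0, eps=1e-5):
    """Apply output = gamma * (x - mean) / sqrt(var + eps) + beta to one channel."""
    mean = statistics.fmean(values)
    var = statistics.pvariance(values, mu=mean)
    return [gamma * (v - mean) / (var + eps) ** 0.5 + beta for v in values]

eps = 1e-5
x = [2.0, 4.0, 6.0, 8.0]                      # made-up activations (mean 5, variance 5)

normalized = batch_norm(x, eps=eps)           # gamma=1, beta=0: pure normalization
print(normalized)

# Choosing gamma = sqrt(var + eps) and beta = mean undoes the normalization.
mean = statistics.fmean(x)
var = statistics.pvariance(x)
recovered = batch_norm(x, gamma=(var + eps) ** 0.5, beta=mean, eps=eps)
print(recovered)                              # approximately [2.0, 4.0, 6.0, 8.0]
```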

**Visual Analogy:**

Imagine trying to hit a moving target — that’s training without BatchNorm. Every time you aim (adjust weights), the target (data distribution) moves.

BatchNorm holds the target steady — so your aim improves faster and more accurately.


In [None]:
class DoubleConv(nn.Module):
  def __init__(self, in_channels, out_channels):
    super(DoubleConv, self).__init__()

    self.conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),

        nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )

  def forward(self, x):
    return self.conv(x)

**UNET parameters and building blocks:**

- in_channels: number of input channels (e.g. 3 for RGB).

- out_channels: number of output channels (e.g. 1 for a binary segmentation mask).

- features: the number of filters at each encoder level. The decoder mirrors these in reverse order.

- Pooling: a 2×2 max pool follows each encoder block and halves the height and width.

- Upsampling: each decoder step applies a transposed convolution, concatenates the result with the matching encoder output (skip connection), then passes it through DoubleConv(feature*2, feature).

- Bottleneck: the deepest layer, between the down and up paths.

- Final output layer: after decoding is complete, a 1×1 convolution reduces the number of channels to out_channels (e.g. 1 for a binary mask).

- Size check: pooling floors odd sizes (e.g. 161 → 80), so the upsampled tensor can be one pixel smaller than its skip connection. The forward pass therefore compares the shapes and resizes before concatenating, to prevent errors.
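
The size mismatch is easiest to see by tracing one spatial dimension through the network. The sketch below assumes four encoder levels, 2×2 pooling with stride 2, and transposed convolutions with kernel 2 and stride 2. `trace_sizes` is a hypothetical helper for this arithmetic only.

```python
def trace_sizes(size, depth=4):
    """Return (upsampled_size, skip_size) pairs for each decoder level."""
    skips = []
    for _ in range(depth):
        skips.append(size)   # size stored for the skip connection
        size //= 2           # 2x2 max pool, stride 2: odd sizes are floored
    pairs = []
    for skip in reversed(skips):
        upsampled = size * 2             # transposed conv (kernel 2, stride 2) doubles the size
        pairs.append((upsampled, skip))
        size = skip                      # after resizing, x matches the skip connection
    return pairs

print(trace_sizes(160))  # [(20, 20), (40, 40), (80, 80), (160, 160)]
print(trace_sizes(161))  # [(20, 20), (40, 40), (80, 80), (160, 161)]
```

With four pooling steps, sizes divisible by 16 never trigger the resize. Any other size produces a mismatch at one or more levels.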



In [None]:
class UNET(nn.Module):
  def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512)):  # tuple avoids a mutable default argument
    super(UNET, self).__init__()

    # To store the encoder and decoder blocks
    self.ups = nn.ModuleList()
    self.downs = nn.ModuleList()
    self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    # Down Part of UNET
    for feature in features:
      self.downs.append(DoubleConv(in_channels, feature))
      in_channels = feature

    # Up Part of UNET
    for feature in reversed(features):
      self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2))

      self.ups.append(DoubleConv(feature*2, feature))

    self.bottleneck = DoubleConv(features[-1], features[-1]*2)
    self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

  def forward(self, x):
    skip_connections = []

    # Downsampling
    for down in self.downs:
      x = down(x)
      skip_connections.append(x)
      x = self.pool(x)

    # Bottleneck
    x = self.bottleneck(x)

    skip_connections = skip_connections[::-1]  # Reverse the skip connections

    # Upsampling
    for idx in range(0, len(self.ups), 2):
      # Going Up
      x = self.ups[idx](x)

      skip_connection = skip_connections[idx//2]

      # Pooling floors odd sizes, so x can be smaller than its skip connection; resize x to match
      if x.shape != skip_connection.shape:
        x = TF.resize(x, size=skip_connection.shape[2:])

      concat_skip = torch.cat((skip_connection, x), dim=1)  # dim=1 to concat through channel dimension - [batch, channel, height, width]

      # DoubleConv
      x = self.ups[idx+1](concat_skip)

    return self.final_conv(x)

**Test Model**

In [None]:
x = torch.randn((3, 1, 160, 160))
model = UNET(1, 1)
preds = model(x)
print(x.shape)  # torch.Size([3, 1, 160, 160])
print(preds.shape)  # torch.Size([3, 1, 160, 160])
assert preds.shape == x.shape

# Training

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm

**Hyperparameters**

In [None]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
LEARNING_RATE = 1e-4
SCHEDULER_STEP_SIZE = 10
BATCH_SIZE = 16
NUM_EPOCHS = 25
NUM_WORKERS = 4
PIN_MEMORY = True
IMAGE_HEIGHT = 160  # Adjust to your dataset
IMAGE_WIDTH = 240
LOAD_MODEL = False
CHECKPOINT_PATH = "unet_checkpoint.pth"
TRAIN_IMAGE_DIR = 'Cloth-Segmentation-Dataset/train_images'
TRAIN_MASK_DIR = 'Cloth-Segmentation-Dataset/train_masks'
VAL_IMAGE_DIR = 'Cloth-Segmentation-Dataset/val_images'
VAL_MASK_DIR = 'Cloth-Segmentation-Dataset/val_masks'

**Transformation**

In [None]:
transform = A.Compose([
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
    A.Normalize(mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0), max_pixel_value=255.0),
    ToTensorV2(),
])
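
For reference, Albumentations documents `A.Normalize` as `img = (img - mean * max_pixel_value) / (std * max_pixel_value)`. With mean 0 and std 1, as configured above, this simply rescales pixel values from [0, 255] to [0, 1]. The sketch below reproduces that arithmetic for a single pixel value. `normalize_pixel` is a hypothetical helper, not a library function.

```python
def normalize_pixel(value, mean=0.0, std=1.0, max_pixel_value=255.0):
    """Apply the documented Normalize formula to one pixel value."""
    return (value - mean * max_pixel_value) / (std * max_pixel_value)

print(normalize_pixel(0))    # 0.0
print(normalize_pixel(255))  # 1.0
```

The normalization applies to the image only. The mask passes through unchanged apart from the resize.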

**Dataset and Dataloader**

In [None]:
train_dataset = ClothDataset(
    img_dir=TRAIN_IMAGE_DIR,
    mask_dir=TRAIN_MASK_DIR,
    transform=transform,
)

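train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,  # Automatically moves data to pinned memory (faster GPU transfer)
    shuffle=True,
)

# Validation data for the evaluation loop (no shuffling needed)
val_dataset = ClothDataset(VAL_IMAGE_DIR, VAL_MASK_DIR, transform=transform)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    shuffle=False,
)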

**Initialize model, loss, and optimizer**

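BCEWithLogitsLoss is a standard choice for binary segmentation. It combines a Sigmoid and BCELoss (Binary Cross Entropy) in a single operation, which is more numerically stable than applying them separately, so the model outputs raw logits.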
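
The sketch below shows, for a single logit, why fusing the two steps is equivalent. It uses one common numerically stable form of the loss and compares it with applying a sigmoid followed by binary cross entropy. Both functions are hypothetical plain-Python helpers for illustration, not PyTorch's implementation.

```python
import math

def bce_with_logits(logit, target):
    """Stable form: max(x, 0) - x * y + log(1 + exp(-|x|))."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))

def sigmoid_then_bce(logit, target):
    """Reference form: sigmoid first, then binary cross entropy."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))

# Both forms agree for moderate logits.
print(bce_with_logits(2.0, 1.0), sigmoid_then_bce(2.0, 1.0))

# The fused form stays finite for extreme logits, where the sigmoid saturates.
print(bce_with_logits(1000.0, 0.0))  # 1000.0
```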

**Learning rate scheduler** adjusts the learning rate during training, instead of keeping it fixed.

**StepLR**:

- Reduces the LR every step_size epochs (e.g., every 10 epochs).

- Multiplies LR by gamma at each step (e.g., gamma = 0.1 divides it by 10; gamma = 0.5, used here, halves it).

This setup (step_size = 10, gamma = 0.5):

- Starts at 1e-4

- At epoch 10: LR → 5e-5

- At epoch 20: LR → 2.5e-5

- ... and so on
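
The schedule follows the closed form lr = base_lr × gamma^(epoch // step_size). The sketch below evaluates it in plain Python, assuming a base rate of 1e-4, step_size = 10 and gamma = 0.5, with `epoch` counted from 0. `step_lr` is a hypothetical helper, not the PyTorch scheduler.

```python
def step_lr(base_lr, epoch, step_size=10, gamma=0.5):
    """Learning rate in effect during a given (0-based) epoch."""
    return base_lr * gamma ** (epoch // step_size)

for epoch in (0, 9, 10, 19, 20):
    print(epoch, step_lr(1e-4, epoch))
# epochs 0-9 use 1e-4, epochs 10-19 use 5e-5, epochs 20-29 use 2.5e-5
```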

In [None]:
model = UNET(in_channels=3, out_channels=1).to(device)
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=SCHEDULER_STEP_SIZE, gamma=0.5)

In [None]:
# Optional: Load checkpoint
if LOAD_MODEL:
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)  # Load onto the active device (CPU or GPU)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

**Save Model**

In [None]:
def save_checkpoint(model, optimizer, filename=CHECKPOINT_PATH):
    print("=> Saving checkpoint")
    torch.save({
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }, filename)

**Training loop**

In [None]:
for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
    model.train()
    running_loss = 0.0

    for data, targets in tqdm(train_loader, desc='Training loop'):
        data = data.to(device)
        targets = targets.float().unsqueeze(1).to(device)  # Add channel dim -> unsqueeze(1) changes shape from [B, H, W] to [B, 1, H, W].

        # Forward
        predictions = model(data)
        loss = loss_fn(predictions, targets)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    scheduler.step()

    print(f'Average Loss: {running_loss / len(train_loader):.4f}')


    for param_group in optimizer.param_groups:
        print(f"Current Learning Rate: {param_group['lr']}")

    # Save every 5th epoch, starting with the first
    if epoch % 5 == 0:
        save_checkpoint(model, optimizer)


# Evaluating

**Save Predictions**

In [None]:
from PIL import Image
import os

def save_predictions_as_images(preds, targets, batch_idx, output_dir="saved_preds"):
    os.makedirs(output_dir, exist_ok=True)

    for i in range(preds.shape[0]):  # Loop through batch
        pred_mask = preds[i].squeeze().cpu().numpy() * 255.0  # [1, H, W] -> [H, W]
        true_mask = targets[i].squeeze().cpu().numpy() * 255.0

        pred_img = Image.fromarray(pred_mask.astype("uint8"))
        target_img = Image.fromarray(true_mask.astype("uint8"))

        pred_img.save(os.path.join(output_dir, f"pred_{batch_idx}_{i}.png"))
        target_img.save(os.path.join(output_dir, f"gt_{batch_idx}_{i}.png"))

**Define a Dice Score Function**

Dice score measures how well the predicted mask overlaps with the ground truth mask. Values range from 0 (no overlap) to 1 (perfect match).
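
Written out, the score is Dice = 2·|A ∩ B| / (|A| + |B|), where A and B are the sets of foreground pixels in the prediction and the ground truth. The sketch below computes it on tiny made-up masks flattened into lists. `dice` is a hypothetical plain-Python helper for illustration.

```python
def dice(pred, target, eps=1e-8):
    """Dice score for two binary masks given as flat lists of 0s and 1s."""
    intersection = sum(p * t for p, t in zip(pred, target))
    return 2.0 * intersection / (sum(pred) + sum(target) + eps)

# Prediction marks 2 pixels, ground truth marks 1, and they share 1: 2*1 / (2+1)
print(round(dice([1, 1, 0, 0], [1, 0, 0, 0]), 4))  # 0.6667

# No shared foreground pixels gives a score of 0.
print(dice([0, 0, 1, 1], [1, 1, 0, 0]))            # 0.0
```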

In [None]:
def dice_score(preds, targets, eps=1e-8):
    intersection = (preds * targets).sum()
    total = preds.sum() + targets.sum()  # Sum of both mask areas (not a true union)

    dice = (2. * intersection) / (total + eps)  # eps avoids division by zero
    return dice.item()

**Validation loop**

In [None]:
model.eval()

num_correct = num_pixels = 0
total_dice = 0.0

with torch.no_grad():
  for batch_idx, (data, targets) in enumerate(tqdm(val_loader, desc='Validation loop')):

    data = data.to(device)
    targets = targets.float().unsqueeze(1).to(device)

    preds = torch.sigmoid(model(data))     # Convert logits to probabilities
    preds = (preds > 0.5).float()           # Binarize predictions

    num_correct += (preds == targets).sum().item()  # .item() keeps the counter a plain Python int
    num_pixels += torch.numel(preds)

    total_dice += dice_score(preds, targets)

    # Save predictions for first few batches only
    if batch_idx < 5:
        save_predictions_as_images(preds, targets, batch_idx)

accuracy = num_correct / num_pixels * 100
avg_dice = total_dice / len(val_loader)

print(f'Pixel Accuracy: {accuracy:.2f}% | Avg Dice Score: {avg_dice:.4f}')