In [1]:
import os
import glob
import random
import pickle
import numpy as np
import cv2

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [2]:
# --- Experiment configuration ------------------------------------------------
dim = 128          # side length (pixels) used for the dataset folder name and for resizing
num_classes = 3    # Trimap {1,2,3}
batch_size = 16
num_epochs = 50
seed = 42

# Seed the Python, NumPy and PyTorch RNGs so that weight initialisation and
# DataLoader shuffling are repeatable across "Restart & Run All".
# (Some GPU kernels may still be non-deterministic.)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

train_images_dir = f'./trainval_{dim}/images'
train_masks_dir  = f'./trainval_{dim}/annotations'

# Images and masks are paired purely by their position in the sorted listings.
image_paths = sorted(glob.glob(os.path.join(train_images_dir, '*.png')))
mask_paths  = sorted(glob.glob(os.path.join(train_masks_dir, '*.png')))

print(f'Found {len(image_paths)} images and {len(mask_paths)} masks')

# Positional pairing is only meaningful when both listings have the same length;
# fail here with a clear message instead of deeper inside the split/training code.
if len(image_paths) != len(mask_paths):
    raise ValueError(
        f"Image/mask count mismatch: {len(image_paths)} images in {train_images_dir} "
        f"vs {len(mask_paths)} masks in {train_masks_dir}"
    )

Found 3680 images and 3680 masks


In [3]:
class myDataset(Dataset):
    """Dataset of (image, trimap mask) pairs loaded from disk with OpenCV.

    Item ``idx`` pairs ``image_paths[idx]`` with ``mask_paths[idx]``; both are
    resized to ``img_size x img_size`` on the fly.
    """

    def __init__(self, image_paths, mask_paths, img_size):
        """
        image_paths: list of file paths to input images.
        mask_paths: list of file paths to corresponding trimap masks.
        img_size: desired image size (both width and height).

        Raises:
            ValueError: if the two path lists differ in length (items are
                paired by index, so a mismatch would misalign images and masks).
        """
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f"image_paths and mask_paths must have the same length, "
                f"got {len(image_paths)} and {len(mask_paths)}"
            )
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.img_size = img_size

    def __len__(self):
        """Number of (image, mask) pairs."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, resize and convert the pair at position ``idx``.

        Returns:
            image_tensor: float32 tensor of shape (3, img_size, img_size), RGB,
                scaled to [0, 1].
            mask_tensor: int64 tensor of shape (img_size, img_size) holding the
                stored mask values minus 1.

        Raises:
            ValueError: if OpenCV cannot read the image or the mask file.
        """
        img_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        # Load image with OpenCV (BGR -> convert to RGB).
        image = cv2.imread(img_path)
        if image is None:
            raise ValueError(f"Image not found at {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (self.img_size, self.img_size))
        image = image.astype(np.float32) / 255.0
        # HWC -> CHW, the layout expected by torch conv layers.
        image = np.transpose(image, (2, 0, 1))

        # Load mask in grayscale and resize using nearest-neighbor to preserve labels.
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise ValueError(f"Mask not found at {mask_path}")
        mask = cv2.resize(mask, (self.img_size, self.img_size), interpolation=cv2.INTER_NEAREST)
        # Shift labels: from {1,2,3} to {0,1,2}.
        # NOTE: the mask is uint8 here, so a stored 0 would wrap around to 255
        # rather than become -1; masks are assumed to contain only {1,2,3}.
        mask = mask - 1

        image_tensor = torch.from_numpy(image).float()
        mask_tensor = torch.from_numpy(mask).long()

        return image_tensor, mask_tensor
    

In [4]:
# Hold out 10% of the (image, mask) pairs for validation. Passing both lists
# to a single call keeps every image aligned with its mask after shuffling.
_split = train_test_split(
    image_paths,
    mask_paths,
    test_size=0.1,
    random_state=42,
    shuffle=True,
)
train_img_paths, val_img_paths, train_mask_paths, val_mask_paths = _split
print(f"Training images: {len(train_img_paths)}  |  Validation images: {len(val_img_paths)}")

# One dataset per split, both resized to `dim` x `dim`.
train_dataset = myDataset(image_paths=train_img_paths, mask_paths=train_mask_paths, img_size=dim)
val_dataset = myDataset(image_paths=val_img_paths, mask_paths=val_mask_paths, img_size=dim)

# Loader settings shared by both splits; only the training loader shuffles.
_loader_kwargs = {"batch_size": batch_size, "num_workers": 4}
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

Training images: 3312  |  Validation images: 368


In [5]:
class DoubleConv(nn.Module):
    """Two consecutive stages of 3x3 Conv -> BatchNorm -> ReLU.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. ``padding=1`` keeps the spatial size unchanged.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # Stage inputs: the block input first, then the first stage's output.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        # Attribute name and layer order determine the state_dict keys.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to ``x`` and return the result."""
        out = self.double_conv(x)
        return out

class UNet(nn.Module):
    """Four-level U-Net (16 -> 256 channels) producing per-pixel class logits.

    Input:  float tensor (B, 3, H, W).
    Output: logits (B, num_classes, H, W).

    H and W no longer have to be multiples of 16: when a pooled size was odd,
    the upsampled feature map is zero-padded to the size of its skip
    connection before concatenation. For sizes divisible by 16 the padding is
    a no-op, so behaviour (and the state_dict layout) is unchanged.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Encoder: each step halves the resolution and doubles the channels.
        self.inc = DoubleConv(3, 16)
        self.down1 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(16, 32))
        self.down2 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(32, 64))
        self.down3 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(64, 128))
        self.down4 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(128, 256))

        # Decoder: transposed conv doubles the resolution; the following
        # DoubleConv sees upsampled + skip channels (hence 2x input channels).
        self.up1 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv1 = DoubleConv(256, 128)

        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv2 = DoubleConv(128, 64)

        self.up3 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.conv3 = DoubleConv(64, 32)

        self.up4 = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)
        self.conv4 = DoubleConv(32, 16)

        # 1x1 conv maps the 16 feature channels to class logits.
        self.outc = nn.Conv2d(16, num_classes, kernel_size=1)

    @staticmethod
    def _merge_skip(upsampled, skip):
        """Pad ``upsampled`` to ``skip``'s spatial size and concat on channels."""
        # MaxPool2d(2) floors odd sizes, so after 2x upsampling the map can be
        # one pixel short of the skip tensor in either spatial dimension.
        diff_h = skip.size(2) - upsampled.size(2)
        diff_w = skip.size(3) - upsampled.size(3)
        if diff_h or diff_w:
            # F.pad order for the last two dims: (left, right, top, bottom).
            upsampled = F.pad(
                upsampled,
                [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
            )
        return torch.cat([upsampled, skip], dim=1)

    def forward(self, x):
        """Return logits of shape (B, num_classes, H, W) for input (B, 3, H, W)."""
        x1 = self.inc(x)       # (B, 16, H, W)
        x2 = self.down1(x1)    # (B, 32, H/2, W/2)
        x3 = self.down2(x2)    # (B, 64, H/4, W/4)
        x4 = self.down3(x3)    # (B, 128, H/8, W/8)
        x5 = self.down4(x4)    # (B, 256, H/16, W/16)

        u1 = self.up1(x5)      # (B, 128, H/8, W/8)
        u1 = self.conv1(self._merge_skip(u1, x4))

        u2 = self.up2(u1)      # (B, 64, H/4, W/4)
        u2 = self.conv2(self._merge_skip(u2, x3))

        u3 = self.up3(u2)      # (B, 32, H/2, W/2)
        u3 = self.conv3(self._merge_skip(u3, x2))

        u4 = self.up4(u3)      # (B, 16, H, W)
        u4 = self.conv4(self._merge_skip(u4, x1))

        logits = self.outc(u4)  # (B, num_classes, H, W)
        return logits

In [7]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = UNet(num_classes=num_classes)
model = model.to(device)

optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Takes logits of shape [B, C, H, W] and integer class targets of shape [B, H, W].
criterion = nn.CrossEntropyLoss()

def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader`` and return the mean loss.

    Args:
        model: network to train; switched to train mode.
        loader: iterable yielding ``(images, masks)`` tensor batches.
        optimizer: optimiser stepping ``model``'s parameters.
        criterion: callable ``criterion(outputs, masks)`` returning a scalar loss.
        device: device the batches are moved to before the forward pass.

    Returns:
        float: per-sample mean of the batch losses, weighted by batch size.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    samples_seen = 0
    for images, masks in loader:
        images = images.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)  # (B, num_classes, H, W)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        batch_n = images.size(0)
        # Weight by batch size so a smaller final batch does not skew the mean.
        loss_sum += loss.item() * batch_n
        samples_seen += batch_n

    if samples_seen == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")
    # Divide by the samples actually processed rather than len(loader.dataset):
    # identical for a plain DataLoader, but also correct with drop_last or a
    # subset sampler, and it works for any iterable of batches.
    return loss_sum / samples_seen

def evaluate(model, loader, criterion, device):
    """Compute mean loss and pixel accuracy of ``model`` over ``loader``.

    Args:
        model: network to evaluate; switched to eval mode.
        loader: iterable yielding ``(images, masks)`` tensor batches.
        criterion: callable ``criterion(outputs, masks)`` returning a scalar loss.
        device: device the batches are moved to before the forward pass.

    Returns:
        (float, float): per-sample mean loss (weighted by batch size) and the
        fraction of mask elements whose argmax prediction equals the target.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    samples_seen = 0
    correct = 0
    total_pixels = 0
    with torch.no_grad():
        for images, masks in loader:
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            loss = criterion(outputs, masks)

            batch_n = images.size(0)
            loss_sum += loss.item() * batch_n
            samples_seen += batch_n

            # Class with the highest logit along the channel dimension.
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == masks).sum().item()
            total_pixels += torch.numel(masks)

    if samples_seen == 0 or total_pixels == 0:
        raise ValueError("evaluate: loader yielded no samples")
    # Normalise by what was actually processed rather than len(loader.dataset):
    # identical for a plain DataLoader, correct for drop_last / subset samplers.
    accuracy = correct / total_pixels
    return loss_sum / samples_seen, accuracy

In [None]:
# Train for `num_epochs`, checkpointing the weights with the lowest validation loss.
best_val_loss = float('inf')
best_model_path = f"unet_model_{dim}_epochs_{num_epochs}.pth"
last_model_path = f"unet_model_{dim}_epochs_{num_epochs}_last.pth"

for epoch in range(1, num_epochs+1):
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)
    
    print(f"Epoch [{epoch}/{num_epochs}]  Train Loss: {train_loss:.4f}  "
          f"Val Loss: {val_loss:.4f}  Val Pixel Acc: {val_acc*100:.2f}%")
    
    # Save the best model (lowest validation loss so far).
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), best_model_path)

# Save the final-epoch weights under a separate name. Writing them to
# `best_model_path` would overwrite the best checkpoint with the last epoch.
torch.save(model.state_dict(), last_model_path)

In [None]:
# Predict on test images and save the results.

# Create the output folder (no-op if it already exists).
output_dir = 'output_unet_test'
os.makedirs(output_dir, exist_ok=True)

# Load the best model. `map_location` lets a checkpoint written on a GPU be
# loaded on a CPU-only machine; the file name matches the one the training
# cell writes for the current `dim` / `num_epochs`.
model.load_state_dict(
    torch.load(f"unet_model_{dim}_epochs_{num_epochs}.pth", map_location=device)
)
# Inference must use BatchNorm running statistics, not single-image batch stats.
model.eval()

# NOTE(security): pickle.load can execute arbitrary code; only load path
# lists that you created yourself.
with open('test_images_paths.pkl', 'rb') as f:
    test_images = pickle.load(f)

with open('test_trimap_paths.pkl', 'rb') as f:
    test_trimap = pickle.load(f)

# Entries are paired by position, so both lists must have the same length.
if len(test_images) != len(test_trimap):
    raise ValueError(
        f"Test image/trimap count mismatch: {len(test_images)} vs {len(test_trimap)}"
    )

for i, (image_path, trimap_path) in enumerate(zip(test_images, test_trimap)):

    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Image not found at {image_path}")

    # The trimap is only read to obtain the target output resolution.
    # (cv2.imread takes read flags, not an `interpolation` argument.)
    trimap = cv2.imread(trimap_path, cv2.IMREAD_GRAYSCALE)
    if trimap is None:
        raise ValueError(f"Mask not found at {trimap_path}")
    out_h, out_w = trimap.shape[:2]

    # Same preprocessing as the training dataset: RGB, dim x dim, [0, 1], CHW.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (dim, dim))
    image = image.astype(np.float32) / 255.0
    image = np.transpose(image, (2, 0, 1))

    image_tensor = torch.from_numpy(image).float().unsqueeze(0)
    image_tensor = image_tensor.to(device)

    with torch.no_grad():
        output = model(image_tensor)
        # (1, C, dim, dim) logits -> (dim, dim) class indices.
        output = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()

    # Shift labels back from {0,1,2} to the trimap convention {1,2,3}.
    output = output + 1
    output = output.astype(np.uint8)

    # Nearest-neighbour keeps label values intact; cv2.resize takes (width, height).
    output = cv2.resize(output, (out_w, out_h), interpolation=cv2.INTER_NEAREST)

    cv2.imwrite(os.path.join(output_dir, f'output_{i}.png'), output)