In [1]:
import os
import numpy as np
from PIL import Image
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms.functional as TF


In [None]:

class ForegroundMaskDataset(Dataset):
    """Pairs each item of an MNIST-style dataset with a precomputed foreground mask.

    The mask for item ``idx`` is read from ``mask_dir/{idx:05d}_fgmask.png``,
    so the mask files must be numbered with the same indices as ``mnist_data``.

    Args:
        mnist_data: indexable dataset whose items are 2-tuples; the first
            element is used as the image, the second is discarded.
        mask_dir: directory containing the mask PNG files.
        transform: optional callable applied to BOTH the image and the mask.
            When omitted, PIL inputs are converted with ``TF.to_tensor`` so the
            mask can still be thresholded and batched.

    Each item is ``(img, mask)`` where ``mask`` is a float tensor holding only
    0.0 / 1.0 (any non-zero mask pixel counts as foreground).
    """

    def __init__(self, mnist_data, mask_dir, transform=None):
        self.mnist_data = mnist_data
        self.mask_dir = mask_dir
        self.transform = transform

    def __len__(self):
        return len(self.mnist_data)

    def __getitem__(self, idx):
        img, _ = self.mnist_data[idx]
        mask_path = os.path.join(self.mask_dir, f"{idx:05d}_fgmask.png")
        # Image.open is lazy and keeps the file open until closed; the context
        # manager releases the handle right after convert() has loaded the pixels,
        # instead of leaking one descriptor per item across DataLoader epochs.
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")

        if self.transform:
            img = self.transform(img)
            mask = self.transform(mask)
        else:
            # Without a transform a PIL mask cannot be compared with `> 0`
            # (and PIL images cannot be collated), so fall back to tensors.
            if not isinstance(img, torch.Tensor):
                img = TF.to_tensor(img)
            mask = TF.to_tensor(mask)

        # Binarize: any non-zero pixel is foreground.
        mask = (mask > 0).float()
        return img, mask

In [3]:
# Shared preprocessing for images AND masks: force a 28x28 size, then convert
# the PIL image to a float tensor scaled to [0, 1].
# NOTE(review): Resize defaults to bilinear interpolation; if the mask PNGs are
# not already 28x28 this softens mask edges before the `> 0` binarization —
# confirm mask size or use a NEAREST resize for masks.
transform = transforms.Compose(
    [transforms.Resize(size=(28, 28)), transforms.ToTensor()]
)

In [None]:

# Seed the global torch RNG up front so that Restart & Run All reproduces the
# same loader shuffling and (in later cells) the same model initialisation.
SEED = 42
torch.manual_seed(SEED)

MNIST_ROOT = "./data"
MASK_DIR = "output/foreground_masks"
TRAIN_FRACTION = 0.8
BATCH_SIZE = 64

# Masks are loaded lazily per item, so a missing directory would otherwise only
# surface as an error in the middle of the training loop.
if not os.path.isdir(MASK_DIR):
    raise FileNotFoundError(f"Mask directory not found: {MASK_DIR}")

mnist_dataset = MNIST(root=MNIST_ROOT, train=True, download=False)
dataset = ForegroundMaskDataset(mnist_dataset, MASK_DIR, transform=transform)

train_size = int(TRAIN_FRACTION * len(dataset))
test_size = len(dataset) - train_size
# A dedicated generator keeps the train/test split fixed even if other code
# consumes the global RNG before this point.
train_ds, test_ds = random_split(
    dataset, [train_size, test_size], generator=torch.Generator().manual_seed(SEED)
)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
class SimpleSegNet(nn.Module):
    """Tiny encoder-decoder producing a per-pixel foreground probability.

    The encoder applies two conv + 2x2 max-pool stages (4x downsampling); the
    decoder applies two stride-2 transposed convolutions (4x upsampling) and a
    final sigmoid, so every output value lies in [0, 1].

    Input:  (B, 1, H, W). The output matches the input size only when H and W
            are divisible by 4 (e.g. 28x28 MNIST).
    Output: (B, 1, H, W) probabilities.
    """

    def __init__(self):
        super().__init__()
        encoder_layers = [
            nn.Conv2d(1, 16, kernel_size=3, padding=1),   # (B,16,28,28)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),                  # (B,16,14,14)
            nn.Conv2d(16, 32, kernel_size=3, padding=1),  # (B,32,14,14)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),                  # (B,32,7,7)
        ]
        decoder_layers = [
            nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2),  # (B,16,14,14)
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, kernel_size=2, stride=2),   # (B,1,28,28)
            nn.Sigmoid(),
        ]
        # Sequential containers keep the same state_dict keys (encoder.N / decoder.N).
        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        latent = self.encoder(x)
        return self.decoder(latent)

In [6]:
def compute_iou(preds, targets, threshold=0.5, empty_value=1.0):
    """Mean per-sample intersection-over-union between predictions and binary targets.

    Args:
        preds: tensor of shape (B, ...) with scores/probabilities; pixels strictly
            above ``threshold`` count as predicted foreground.
        targets: tensor with the same shape as ``preds``; assumed to hold 0/1 values.
        threshold: binarization cut-off applied to ``preds``.
        empty_value: IoU assigned to a sample whose predicted AND target masks are
            both empty (union == 0). The default 1.0 treats "nothing predicted,
            nothing present" as perfect agreement instead of producing NaN.

    Returns:
        Python float: mean IoU over the batch dimension.
    """
    pred_masks = (preds > threshold).float()
    # Reduce over every non-batch dimension, so (B,H,W) and (B,C,H,W) both work.
    pred_flat = pred_masks.flatten(1)
    target_flat = targets.flatten(1)
    intersection = (pred_flat * target_flat).sum(dim=1)
    union = ((pred_flat + target_flat) > 0).float().sum(dim=1)
    # clamp avoids 0/0; those entries are replaced by empty_value anyway.
    iou = torch.where(
        union > 0,
        intersection / union.clamp(min=1),
        torch.full_like(union, empty_value),
    )
    return iou.mean().item()

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = SimpleSegNet().to(device)

# SimpleSegNet ends in nn.Sigmoid, so it emits probabilities, which is what BCELoss expects.
criterion = nn.BCELoss()
learning_rate = 0.001
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
epochs = 5


In [8]:
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    samples_seen = 0
    for imgs, masks in train_loader:
        imgs, masks = imgs.to(device), masks.to(device)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        # BCELoss returns a batch mean; weight it by batch size so a short final
        # batch does not skew the epoch average.
        batch_size = imgs.size(0)
        running_loss += loss.item() * batch_size
        samples_seen += batch_size

    epoch_loss = running_loss / max(samples_seen, 1)
    print(f"[Epoch {epoch+1}] Loss: {epoch_loss:.4f}")


[Epoch 1] Loss: 0.1093
[Epoch 2] Loss: 0.0398
[Epoch 3] Loss: 0.0351
[Epoch 4] Loss: 0.0328
[Epoch 5] Loss: 0.0312


In [None]:

model.eval()
ious = []  # per-batch mean IoU values, kept for inspection
iou_sum = 0.0
n_samples = 0
with torch.no_grad():
    for imgs, masks in test_loader:
        imgs, masks = imgs.to(device), masks.to(device)
        outputs = model(imgs)
        iou = compute_iou(outputs, masks)
        ious.append(iou)
        # compute_iou returns a batch mean; re-weight by batch size so the
        # smaller final batch counts per sample, not as a whole batch.
        batch_size = imgs.size(0)
        iou_sum += iou * batch_size
        n_samples += batch_size

mean_iou = iou_sum / n_samples if n_samples else float("nan")
print(f"[✓] Mean IoU on test set: {mean_iou:.4f}")


[✓] Mean IoU on test set: 0.9124
