# Dust Removal for Scanned Film Photography

This notebook demonstrates training a U-Net model to detect dust on scanned film photographs using your own dataset.

## Dataset Creation

Create a PyTorch dataset that loads your clean and dusty image pairs.

In [6]:
from torch.utils.data import Dataset
import cv2
import numpy as np
import os
from torchvision import transforms

class DustDataset(Dataset):
    def __init__(self, clean_dir, dusty_dir, image_size=(1024, 1024), transform=None):
        self.clean_dir = clean_dir
        self.dusty_dir = dusty_dir
        self.image_size = image_size
        self.transform = transform

        self.clean_image_paths = sorted([os.path.join(clean_dir, f) for f in os.listdir(clean_dir)
                                         if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
        
        self.dusty_image_paths = []
        for clean_path in self.clean_image_paths:
            clean_filename = os.path.basename(clean_path)
            dusty_path = os.path.join(dusty_dir, clean_filename)
            if os.path.exists(dusty_path):
                self.dusty_image_paths.append(dusty_path)
            else:
                raise FileNotFoundError(f"Could not find corresponding dusty image for {clean_path} at {dusty_path}")

        if len(self.clean_image_paths) != len(self.dusty_image_paths):
            raise ValueError("Number of clean images and dusty images must be the same.")

    def __len__(self):
        return len(self.clean_image_paths)

    def __getitem__(self, idx):
        clean_path = self.clean_image_paths[idx]
        dusty_path = self.dusty_image_paths[idx]

        clean_img = cv2.imread(clean_path)
        dusty_img = cv2.imread(dusty_path)

        # Convert to greyscale
        if len(clean_img.shape) == 3:
            clean_img = cv2.cvtColor(clean_img, cv2.COLOR_BGR2GRAY)
        if len(dusty_img.shape) == 3:
            dusty_img = cv2.cvtColor(dusty_img, cv2.COLOR_BGR2GRAY)
        
        # Resize images
        clean_img = cv2.resize(clean_img, self.image_size)
        dusty_img = cv2.resize(dusty_img, self.image_size)

        # Calculate mask: difference between dusty and clean images
        # Normalize to [0, 1] and ensure it's float32
        mask = np.abs(dusty_img.astype(np.float32) - clean_img.astype(np.float32)) / 255.0
        
        # Normalize and prepare inputs - convert greyscale to single channel format
        image = dusty_img.astype(np.float32) / 255.0
        image = image[..., np.newaxis]  # Add channel dimension: HW -> HW1
        mask = mask[..., np.newaxis]  # HW -> HW1, already in [0, 1] range

        # Apply augmentation (optional)
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # CHW format
        image = np.transpose(image, (2, 0, 1))
        mask = np.transpose(mask, (2, 0, 1))

        return image, mask

### Data Augmentation

Define augmentations specific to film photography including rotations and shifts that simulate scanning variations.

In [7]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Greyscale-safe augmentation pipeline: geometric transforms only, nothing
# that assumes RGB input. Film is often scanned slightly skewed or offset,
# so the jitter is deliberately mild.
geometric_augmentations = [
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.3),
    # Small rotations for scanning variations
    A.Rotate(limit=5, p=0.4),
    A.ShiftScaleRotate(
        shift_limit=0.05,
        scale_limit=0.1,
        rotate_limit=3,
        p=0.3,
    ),
]
transform = A.Compose(geometric_augmentations)

## U-Net Model Architecture

Implement a U-Net for semantic segmentation adapted for grayscale dust detection. The encoder-decoder structure with skip connections preserves fine details while building semantic understanding.

In [8]:
import torch
import torch.nn as nn

class UNet(nn.Module):
    """Four-level U-Net that maps an image to a per-pixel probability map.

    The encoder halves the spatial resolution four times with max-pooling;
    the decoder restores it with transposed convolutions and concatenates
    the matching encoder feature map at every level (skip connections).
    The output goes through a sigmoid, so values lie in [0, 1].

    Args:
        in_channels: Channels of the input image (1 for greyscale).
        out_channels: Channels of the predicted map.
        base_channels: Width of the first encoder level; each deeper level
            doubles it. The default of 64 gives 64-128-256-512-1024.

    The defaults reproduce the original fixed architecture, and the layer
    attribute names are unchanged, so existing checkpoints keep loading.
    """

    def __init__(self, in_channels=1, out_channels=1, base_channels=64):
        super().__init__()

        def conv_block(in_c, out_c):
            # Two 3x3 convolutions; padding=1 keeps the spatial size.
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, 3, padding=1), nn.ReLU(),
                nn.Conv2d(out_c, out_c, 3, padding=1), nn.ReLU()
            )

        c1, c2, c3, c4, c5 = (base_channels * 2 ** level for level in range(5))

        # Encoder
        self.enc1 = conv_block(in_channels, c1)
        self.enc2 = conv_block(c1, c2)
        self.enc3 = conv_block(c2, c3)
        self.enc4 = conv_block(c3, c4)

        self.pool = nn.MaxPool2d(2)

        # Bottleneck
        self.middle = conv_block(c4, c5)

        # Decoder: each dec block takes the upsampled features concatenated
        # with the skip connection, hence twice the upsampled channel count.
        self.up4 = nn.ConvTranspose2d(c5, c4, 2, stride=2)
        self.dec4 = conv_block(c5, c4)
        self.up3 = nn.ConvTranspose2d(c4, c3, 2, stride=2)
        self.dec3 = conv_block(c4, c3)
        self.up2 = nn.ConvTranspose2d(c3, c2, 2, stride=2)
        self.dec2 = conv_block(c3, c2)
        self.up1 = nn.ConvTranspose2d(c2, c1, 2, stride=2)
        self.dec1 = conv_block(c2, c1)

        self.final = nn.Conv2d(c1, out_channels, 1)

    def forward(self, x):
        """Run the network on a batch shaped (N, in_channels, H, W).

        Returns:
            Tensor shaped (N, out_channels, H, W) with values in [0, 1].

        Raises:
            ValueError: If H or W is not a multiple of 16. Four 2x poolings
                would otherwise round down and the upsampled feature maps
                would no longer match the skip connections.
        """
        height, width = x.shape[-2:]
        if height % 16 or width % 16:
            raise ValueError(
                f"UNet input height and width must be multiples of 16, got {height}x{width}"
            )

        e1 = self.enc1(x)              # H
        e2 = self.enc2(self.pool(e1))  # H/2
        e3 = self.enc3(self.pool(e2))  # H/4
        e4 = self.enc4(self.pool(e3))  # H/8

        m = self.middle(self.pool(e4))  # H/16

        d4 = self.dec4(torch.cat([self.up4(m), e4], dim=1))   # H/8
        d3 = self.dec3(torch.cat([self.up3(d4), e3], dim=1))  # H/4
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))  # H/2
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))  # H

        return torch.sigmoid(self.final(d1))

## Model Training

Train the U-Net model on image pairs resized to 512x512, using a small batch size of 4 to fit memory constraints and a learning rate of 1e-4.

In [9]:
from torch.utils.data import DataLoader
from tqdm import tqdm

# Pick the best available backend: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Every clean/dusty pair is resized to 512x512 by the dataset.
dataset = DustDataset(
    "user_dataset/clean",
    "user_dataset/dusty",
    image_size=(512, 512),
    transform=transform,
)
# Four samples per batch, reshuffled every epoch.
loader = DataLoader(dataset, batch_size=4, shuffle=True)

model = UNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# The model output is already passed through a sigmoid, so plain BCE applies.
criterion = nn.BCELoss()



In [10]:

os.makedirs("checkpoints", exist_ok=True)

# Train for 20 epochs and write a checkpoint after each one.
for epoch in range(20):
    model.train()
    batch_losses = []

    for images, masks in tqdm(loader):
        images, masks = images.to(device), masks.to(device)

        # Clear the gradients left over from the previous step, then do one update.
        optimizer.zero_grad()
        loss = criterion(model(images), masks)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    # Mean of the per-batch losses over the epoch.
    avg_loss = sum(batch_losses) / len(loader)
    print(f"Epoch {epoch+1}: Loss = {avg_loss:.4f}")

    # Keep the weights of every epoch so any of them can be loaded later.
    torch.save(model.state_dict(), f"checkpoints/v6_bce_unet_epoch{epoch+1:02d}.pth")

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.14s/it]


Epoch 1: Loss = 0.6990


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.17s/it]


Epoch 2: Loss = 0.6957


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.29s/it]


Epoch 3: Loss = 0.6920


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.32s/it]


Epoch 4: Loss = 0.6884


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.55s/it]


Epoch 5: Loss = 0.6849


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.17s/it]


Epoch 6: Loss = 0.6812


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.66s/it]


Epoch 7: Loss = 0.6772


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:06<00:00,  6.12s/it]


Epoch 8: Loss = 0.6725


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:07<00:00,  7.13s/it]


Epoch 9: Loss = 0.6678


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:07<00:00,  7.15s/it]


Epoch 10: Loss = 0.6630


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:06<00:00,  6.35s/it]


Epoch 11: Loss = 0.6582


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.72s/it]


Epoch 12: Loss = 0.6520


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.53s/it]


Epoch 13: Loss = 0.6456


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.36s/it]


Epoch 14: Loss = 0.6381


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.30s/it]


Epoch 15: Loss = 0.6288


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.33s/it]


Epoch 16: Loss = 0.6183


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.40s/it]


Epoch 17: Loss = 0.6046


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.37s/it]


Epoch 18: Loss = 0.5861


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.68s/it]


Epoch 19: Loss = 0.5607


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:06<00:00,  6.38s/it]

Epoch 20: Loss = 0.5248





In [11]:
from PIL import Image

def predict_dust_mask(model, weights_path, image_path, threshold=0.5, window_size=1024, stride=512, device=None):
    """Predict a per-pixel dust map for a full image using sliding windows.

    The image is converted to greyscale, reflect-padded on the bottom and
    right so that the windows tile it exactly, and every
    ``window_size`` x ``window_size`` patch is run through ``model``. Where
    windows overlap, the prediction is the plain mean over all windows that
    cover the pixel.

    Args:
        model: Network to run. The weights from ``weights_path`` are loaded
            into it in place and it is switched to eval mode.
        weights_path: Path to a state dict saved with ``torch.save``.
        image_path: Path of the image to analyse.
        threshold: Currently unused; kept for backward compatibility. The
            returned map is NOT thresholded, so callers binarise it themselves.
        window_size: Side length of the square patches fed to the model.
        stride: Step between neighbouring windows; must not exceed
            ``window_size``, otherwise some pixels would never be predicted.
        device: Torch device to run on. Defaults to CUDA, then MPS, then CPU.

    Returns:
        float32 numpy array with the shape (H, W) of the input image holding
        the averaged model output.

    Raises:
        ValueError: If ``window_size`` or ``stride`` is not positive, or if
            ``stride`` is larger than ``window_size``.
    """
    if window_size <= 0 or stride <= 0:
        raise ValueError(f"window_size and stride must be positive, got {window_size} and {stride}")
    if stride > window_size:
        raise ValueError(
            f"stride ({stride}) must not exceed window_size ({window_size}); "
            "windows would leave gaps with no prediction"
        )

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"))  # Prefer CUDA, then MPS, else CPU

    # Load model
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model = model.to(device)
    model.eval()

    # Load image; the context manager closes the underlying file promptly.
    with Image.open(image_path) as img:
        image_np = np.array(img.convert('L')).astype(np.float32)
    H, W = image_np.shape

    # Pad image to cover full area with sliding windows: either up to one
    # window (small images) or so that (size - window_size) is a multiple
    # of the stride.
    pad_h = (stride - ((H - window_size) % stride)) % stride if H >= window_size else (window_size - H)
    pad_w = (stride - ((W - window_size) % stride)) % stride if W >= window_size else (window_size - W)
    padded = np.pad(image_np, ((0, pad_h), (0, pad_w)), mode='reflect')
    pH, pW = padded.shape

    # Running sum of predictions and number of windows covering each pixel.
    prediction_map = np.zeros((pH, pW), dtype=np.float32)
    count_map = np.zeros((pH, pW), dtype=np.float32)

    with torch.no_grad():
        for y in range(0, pH - window_size + 1, stride):
            for x in range(0, pW - window_size + 1, stride):
                patch = padded[y:y+window_size, x:x+window_size]
                # HW -> 1x1xHxW, scaled to [0, 1] like the training inputs.
                patch_tensor = torch.from_numpy(patch).float().unsqueeze(0).unsqueeze(0) / 255.0
                patch_tensor = patch_tensor.to(device)
                pred = model(patch_tensor)
                pred_np = pred.squeeze().cpu().numpy()
                prediction_map[y:y+window_size, x:x+window_size] += pred_np
                count_map[y:y+window_size, x:x+window_size] += 1.0

    # Average overlapping predictions, then crop the padding away.
    final_mask = prediction_map / np.maximum(count_map, 1e-8)
    final_mask = final_mask[:H, :W]

    return final_mask

## Inference Pipeline

Implement patch-based inference for processing large images. The image is split into overlapping windows, and the predictions of all windows covering a pixel are averaged for seamless reconstruction.

In [12]:
# Sample scan to analyse and the checkpoint written after the final epoch.
image_path = "user_dataset/dusty/1.jpg"
weights_path = "checkpoints/v6_bce_unet_epoch20.pth"

# Fresh network instance; the trained weights are loaded inside predict_dust_mask.
model = UNet()
dust = predict_dust_mask(
    model=model,
    weights_path=weights_path,
    image_path=image_path,
    threshold=0.2,
    window_size=1024,
    stride=980,
)

KeyboardInterrupt: 

### Testing Dust Detection

Load a trained model and test dust detection on a sample image. Adjust threshold to balance detection sensitivity.

In [None]:
import skimage.io as skio
import matplotlib.pyplot as plt
# Binarise the predicted map at 0.01 and show it: white = flagged as dust.
dust_binary = dust > 0.01
plt.imshow(dust_binary, cmap='gray')