In [None]:
# Install necessary packages
!pip install tqdm torch torchvision

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.models.segmentation import deeplabv3_mobilenet_v3_large, DeepLabV3_MobileNet_V3_Large_Weights
from tqdm import tqdm
from pathlib import Path
from PIL import Image
import random

# Pick the compute device: the CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# SD1 dataset locations (relative paths resolve against the current working directory).
DATA_DIR = Path("..") / "data" / "SD1"
PROCESSED_TRAIN_DIR = DATA_DIR.joinpath("processed_train")
PROCESSED_VAL_DIR = DATA_DIR.joinpath("processed_val")

# Custom dataset class for paired images
class GlareRemovalDataset(Dataset):
    """Paired glare / ground-truth image dataset.

    Every ``glare_<id>.png`` in ``root_dir`` is paired with the
    ``gt_<id>.png`` that shares the same ``<id>``; glare images without a
    matching ground-truth file are skipped.

    Args:
        root_dir: Directory holding the ``glare_*.png`` / ``gt_*.png`` files.
        transform: Optional callable applied to each image of a pair. It is
            called separately for the glare and the ground-truth image, so a
            random augmentation would NOT be applied consistently to both.
        limit: If truthy, keep a random subset of at most ``limit`` pairs
            (handy for quick debugging runs).
        seed: Optional seed for the ``limit`` subsampling. ``None`` draws from
            the global ``random`` state, as before.

    Items are ``(glare_img, gt_img)`` tuples: RGB PIL images, or whatever
    ``transform`` returns for them.
    """

    _GLARE_PREFIX = "glare_"
    _GT_PREFIX = "gt_"

    def __init__(self, root_dir, transform=None, limit=None, seed=None):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.image_pairs = []

        # Sorted so the pair order is deterministic regardless of filesystem order.
        for glare_path in sorted(self.root_dir.glob(f"{self._GLARE_PREFIX}*.png")):
            # Swap only the leading prefix: str.replace would also rewrite a
            # later "glare_" inside the id and then miss the real gt file.
            gt_name = self._GT_PREFIX + glare_path.name[len(self._GLARE_PREFIX):]
            gt_path = self.root_dir / gt_name
            if gt_path.exists():
                self.image_pairs.append((glare_path, gt_path))

        # Limit dataset size for debugging.
        if limit:
            rng = random if seed is None else random.Random(seed)
            self.image_pairs = rng.sample(self.image_pairs, min(limit, len(self.image_pairs)))

    def __len__(self):
        return len(self.image_pairs)

    @staticmethod
    def _load_rgb(path):
        """Load ``path`` as an RGB PIL image and close the underlying file."""
        # convert() returns a new, fully loaded image, so it stays valid after
        # the context manager closes the file handle.
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, idx):
        glare_path, gt_path = self.image_pairs[idx]
        glare_img = self._load_rgb(glare_path)
        gt_img = self._load_rgb(gt_path)

        if self.transform:
            glare_img = self.transform(glare_img)
            gt_img = self.transform(gt_img)

        return glare_img, gt_img

# Transformations: no resizing is needed because the processed images are
# already 512x512. ToTensor turns each PIL image into a float CHW tensor in [0, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = GlareRemovalDataset(PROCESSED_TRAIN_DIR, transform=transform)
val_dataset = GlareRemovalDataset(PROCESSED_VAL_DIR, transform=transform)

# Fail early with a readable message instead of the sampler's
# "num_samples should be a positive integer" error on an empty dataset.
if len(train_dataset) == 0:
    raise FileNotFoundError(f"No glare_/gt_ image pairs found in {PROCESSED_TRAIN_DIR}")

BATCH_SIZE = 16  # increase if memory allows
# NOTE(review): num_workers is kept at 0. Worker processes would have to
# re-import the Dataset class, which presumably fails for notebook-defined
# classes on spawn-based platforms (Windows/macOS). Raise it when running as a
# script to parallelize image decoding.
NUM_WORKERS = 0
# Page-locked host memory speeds up host-to-GPU copies; it has no benefit on CPU.
PIN_MEMORY = device.type == "cuda"

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

# Define a smaller model (DeepLabV3-MobileNetV3)
class UNet(nn.Module):
    """DeepLabV3 with a MobileNetV3-Large backbone that predicts an image.

    The historical name ``UNet`` is kept so existing code keeps working, but
    the architecture is DeepLabV3, not a U-Net.

    Args:
        out_channels: Channels produced by the final 1x1 conv (3 = RGB).
        weights: torchvision weights enum used to initialize the network.
            Defaults to the pretrained ``DEFAULT`` weights; pass ``None`` to
            skip the pretrained segmentation weights.

    NOTE(review): the pretrained weights expect ImageNet-normalized inputs,
    while the notebook feeds raw [0, 1] tensors. Consider adding
    ``transforms.Normalize`` (confirm intent).

    NOTE(review): with pretrained weights torchvision also builds an auxiliary
    head, whose output ``forward`` discards. Setting
    ``self.model.aux_classifier = None`` would save compute, but it changes
    the ``state_dict`` keys of saved checkpoints.
    """

    def __init__(self, out_channels=3, weights=DeepLabV3_MobileNet_V3_Large_Weights.DEFAULT):
        super().__init__()
        self.model = deeplabv3_mobilenet_v3_large(weights=weights)

        # Replace the 21-class output conv with one producing `out_channels`
        # maps. Reuse the existing layer's input width rather than hard-coding it.
        old_head = self.model.classifier[4]
        self.model.classifier[4] = nn.Conv2d(old_head.in_channels, out_channels,
                                             kernel_size=(1, 1), stride=(1, 1))

    def forward(self, x):
        """Return the network's main ``'out'`` prediction for batch ``x``."""
        return self.model(x)["out"]

# Build the network and place its parameters on the selected device.
model = UNet().to(device)

# Pixel-wise mean absolute error against the ground truth, minimized with Adam.
criterion = nn.L1Loss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

# Training loop: optimize on train_loader, measure the mean L1 loss on
# val_loader after each epoch, and save periodic and final checkpoints.
num_epochs = 3
checkpoint_every = 2  # save an intermediate checkpoint every N epochs
models_dir = Path("../models")
# torch.save does not create missing directories, so make sure the target exists.
models_dir.mkdir(parents=True, exist_ok=True)

for epoch in range(num_epochs):
    # ---- train ----
    model.train()
    train_loss_sum = 0.0
    train_count = 0

    progress = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")
    for images, targets in progress:
        # non_blocking only has an effect when the loader pins host memory.
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(images)

        # The model's output already matches the target's spatial size.
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Weight each batch's mean loss by its size, so a short final batch
        # does not skew the epoch average.
        batch_loss = loss.item()
        train_loss_sum += batch_loss * images.size(0)
        train_count += images.size(0)
        progress.set_postfix(loss=f"{batch_loss:.4f}")

    train_loss = train_loss_sum / train_count if train_count else float("nan")

    # ---- validate ----
    # eval() switches batch-norm/dropout to inference behaviour; no_grad()
    # skips building the autograd graph.
    model.eval()
    val_loss_sum = 0.0
    val_count = 0
    with torch.no_grad():
        for images, targets in val_loader:
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            val_loss_sum += criterion(model(images), targets).item() * images.size(0)
            val_count += images.size(0)
    val_loss = val_loss_sum / val_count if val_count else float("nan")

    print(f"Epoch {epoch + 1}, Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    # Save model checkpoint
    if (epoch + 1) % checkpoint_every == 0:
        checkpoint_path = str(models_dir / f"unet__mobilenet_epoch_{epoch + 1}.pth")
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Checkpoint saved to {checkpoint_path}")

# Save final model
final_model_path = str(models_dir / "final_unet_mobilenet.pth")
torch.save(model.state_dict(), final_model_path)
print(f"Final model saved at {final_model_path}")



Using device: cpu


Epoch 1/3:   0%|                                                                                                                                                                                      | 0/750 [00:00<?, ?it/s]