In [1]:
# Install necessary packages
!pip install tqdm torch torchvision pytorch-msssim

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.models.segmentation import (
    deeplabv3_mobilenet_v3_large,
    DeepLabV3_MobileNet_V3_Large_Weights
)
from tqdm import tqdm
from pathlib import Path
from PIL import Image
import random
import numpy as np
import torch.nn.functional as F
from pytorch_msssim import ssim

# ✅ Select the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# ✅ Paths (relative to the notebook's current working directory)
DATA_DIR = Path("..", "data", "SD1")
PROCESSED_TRAIN_DIR = DATA_DIR.joinpath("processed_train")
PROCESSED_VAL_DIR = DATA_DIR.joinpath("processed_val")
MODEL_DIR = Path("..", "models")
# Create the checkpoint directory up front so later torch.save calls cannot fail on it.
MODEL_DIR.mkdir(parents=True, exist_ok=True)

# ✅ Dataset class
class GlareRemovalDataset(Dataset):
    """Paired (glare, ground-truth) images read from a single directory.

    Every ``glare_<suffix>.png`` in ``root_dir`` is paired with
    ``gt_<suffix>.png`` from the same directory. Glare images without a
    matching ground-truth file are skipped. Pairs are kept in sorted
    glare-filename order unless ``limit`` subsamples them.

    Args:
        root_dir: Directory containing the ``glare_*.png`` / ``gt_*.png`` files.
        transform: Optional callable applied to each image of a pair (e.g. a
            ``transforms.Compose`` with ``ToTensor``). It is called separately
            on the glare and the gt image, so random augmentations would NOT be
            synchronized across the pair.
        limit: If truthy, keep a random subset of at most ``limit`` pairs,
            drawn with the module-level ``random`` state (intended for debugging).
    """

    def __init__(
        self,
        root_dir: "str | Path",
        transform: "transforms.Compose | None" = None,
        limit: "int | None" = None,
    ):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.image_pairs = []

        glare_prefix, gt_prefix = "glare_", "gt_"
        for glare_path in sorted(self.root_dir.glob("glare_*.png")):
            # Swap only the leading prefix: str.replace would also rewrite any
            # later "glare_" in the name (e.g. "glare_sunglare_01.png").
            suffix = glare_path.name[len(glare_prefix):]
            gt_path = self.root_dir / f"{gt_prefix}{suffix}"
            if gt_path.exists():
                self.image_pairs.append((glare_path, gt_path))

        # Limit dataset size for debugging
        if limit:
            self.image_pairs = random.sample(self.image_pairs, min(limit, len(self.image_pairs)))

    def __len__(self) -> int:
        """Return the number of (glare, gt) pairs."""
        return len(self.image_pairs)

    def __getitem__(self, idx: int):
        """Return the ``(glare, gt)`` pair at ``idx`` as RGB images, transformed if configured."""
        glare_path, gt_path = self.image_pairs[idx]
        # `with` closes each file handle deterministically. convert() returns a
        # new, fully loaded image that remains valid after the file is closed.
        with Image.open(glare_path) as img:
            glare_img = img.convert("RGB")
        with Image.open(gt_path) as img:
            gt_img = img.convert("RGB")

        if self.transform:
            glare_img = self.transform(glare_img)
            gt_img = self.transform(gt_img)

        return glare_img, gt_img

# ✅ Transformations
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert to tensor (values between 0-1)
])

# ✅ Datasets
train_dataset = GlareRemovalDataset(PROCESSED_TRAIN_DIR, transform=transform)
val_dataset = GlareRemovalDataset(PROCESSED_VAL_DIR, transform=transform)

# Fail fast on an empty split. Otherwise an empty train set surfaces as an opaque
# sampler error, and an empty val set as a ZeroDivisionError only after a full
# training epoch has already run.
for _split_name, _split_ds, _split_dir in (
    ("train", train_dataset, PROCESSED_TRAIN_DIR),
    ("val", val_dataset, PROCESSED_VAL_DIR),
):
    if len(_split_ds) == 0:
        raise RuntimeError(
            f"No glare_*/gt_* image pairs found for the {_split_name} split in {_split_dir}"
        )

# ✅ Data loaders
# Pinned host memory only speeds up host->GPU copies. On CPU it is useless and
# PyTorch warns about it, so enable it only when training on CUDA.
_pin_memory = device.type == "cuda"
# num_workers=0 loads batches in the main process.
# NOTE(review): more workers may speed up loading, but under the spawn start method
# (Windows/Jupyter) workers typically cannot import a Dataset class defined in a notebook.
train_loader = DataLoader(
    train_dataset, batch_size=16, shuffle=True, num_workers=0, pin_memory=_pin_memory
)
val_loader = DataLoader(
    val_dataset, batch_size=16, shuffle=False, num_workers=0, pin_memory=_pin_memory
)

# ✅ Model Definition (DeepLabV3-MobileNetV3)
class GlareRemovalModel(nn.Module):
    """Image-to-image glare-removal network wrapping torchvision's DeepLabV3-MobileNetV3-Large.

    The segmentation model is loaded with its ``DEFAULT`` pretrained weights.
    ``classifier[4]`` is then replaced by a freshly initialized 256->3 channel
    1x1 convolution, so the head emits 3 channels instead of class scores.
    ``forward`` returns the wrapped model's ``'out'`` entry.

    NOTE(review): there is no output activation, so predictions are not
    constrained to [0, 1] even though the loss treats them as images in that range.
    NOTE(review): inputs come from ToTensor without ImageNet mean/std
    normalization; the pretrained backbone presumably expects it — confirm.
    NOTE(review): with pretrained weights torchvision presumably also builds an
    auxiliary head whose output is computed and then discarded here — confirm.
    """

    def __init__(self):
        super().__init__()
        pretrained = DeepLabV3_MobileNet_V3_Large_Weights.DEFAULT
        self.model = deeplabv3_mobilenet_v3_large(weights=pretrained)
        # Swap classifier[4] for a 256 -> 3 channel 1x1 conv so the network predicts RGB.
        self.model.classifier[4] = nn.Conv2d(256, 3, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped model on ``x`` and return its main (``'out'``) prediction."""
        predictions = self.model(x)
        return predictions["out"]

# ✅ Instantiate the model on the selected device
model = GlareRemovalModel()
model.to(device)  # nn.Module.to moves parameters/buffers in place and returns the same module

# ✅ Custom Loss (L1 + SSIM Loss)
def custom_loss(
    predicted: torch.Tensor,
    target: torch.Tensor,
    ssim_weight: float = 0.5,
) -> torch.Tensor:
    """Return the combined L1 + weighted SSIM reconstruction loss.

    loss = L1(predicted, target) + ssim_weight * (1 - SSIM(predicted, target))

    Args:
        predicted: Model output batch. NOTE(review): assumed (N, C, H, W) with
            values meant to lie in [0, 1] (``data_range=1`` below); this is not
            enforced here — confirm upstream.
        target: Ground-truth batch with the same shape as ``predicted``.
        ssim_weight: Weight of the structural (1 - SSIM) term. The default 0.5
            reproduces the previously hard-coded behavior.

    Returns:
        Scalar tensor: mean absolute error plus the weighted, batch-averaged SSIM loss.
    """
    l1_term = F.l1_loss(predicted, target)
    # data_range=1 matches targets produced by ToTensor (values in [0, 1]).
    ssim_term = 1 - ssim(predicted, target, data_range=1, size_average=True)
    return l1_term + ssim_weight * ssim_term

# ✅ Optimizer & Scheduler
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# Halve the learning rate once the value passed to scheduler.step() (the
# validation loss) has failed to improve for more than 3 consecutive epochs.
# mode="min" is the default; it is spelled out because lower loss is better.
# NOTE(review): with a 5-epoch run this can fire at most once, after the final epoch.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=3,
)

# ✅ Training loop
num_epochs = 5  # number of full passes over the training set
best_val_loss = float("inf")
best_model_path = MODEL_DIR / "best_glare_removal_model.pth"


def _require_samples(n_samples: int, split: str) -> None:
    """Raise a descriptive error instead of dividing by zero on an empty split."""
    if n_samples == 0:
        raise RuntimeError(
            f"The {split} loader yielded no samples; check the processed data directory."
        )


def _train_one_epoch(model, loader, optimizer, loss_fn, device, desc):
    """Run one optimization pass over ``loader`` and return the per-sample mean loss."""
    model.train()
    total_loss, n_samples = 0.0, 0
    for images, targets in tqdm(loader, desc=desc):
        images, targets = images.to(device), targets.to(device)

        optimizer.zero_grad()
        loss = loss_fn(model(images), targets)
        loss.backward()

        # ✅ Gradient clipping (global norm <= 1.0) to avoid exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # loss_fn returns a batch mean; weight it by batch size so a smaller
        # final batch does not skew the epoch average.
        batch_size = images.size(0)
        total_loss += loss.item() * batch_size
        n_samples += batch_size

    _require_samples(n_samples, "training")
    return total_loss / n_samples


def _evaluate(model, loader, loss_fn, device):
    """Return the per-sample mean of ``loss_fn`` over ``loader`` in eval mode, without gradients."""
    model.eval()
    total_loss, n_samples = 0.0, 0
    with torch.no_grad():
        for images, targets in loader:
            images, targets = images.to(device), targets.to(device)
            batch_size = images.size(0)
            total_loss += loss_fn(model(images), targets).item() * batch_size
            n_samples += batch_size

    _require_samples(n_samples, "validation")
    return total_loss / n_samples


for epoch in range(num_epochs):
    avg_train_loss = _train_one_epoch(
        model, train_loader, optimizer, custom_loss, device,
        desc=f"Epoch {epoch + 1}/{num_epochs}",
    )
    print(f"Epoch {epoch+1} - Training Loss: {avg_train_loss:.4f}")

    # ✅ Validation
    avg_val_loss = _evaluate(model, val_loader, custom_loss, device)
    print(f"Epoch {epoch+1} - Validation Loss: {avg_val_loss:.4f}")

    # ✅ Let ReduceLROnPlateau react to the validation loss
    scheduler.step(avg_val_loss)

    # ✅ Save best model checkpoint (weights only: model.state_dict())
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), best_model_path)
        print(f"✅ Best Model Saved at: {best_model_path}")

    # ✅ Periodic checkpoint every 2 epochs.
    # NOTE(review): only model weights are saved, so resuming training would also
    # need optimizer/scheduler state.
    if (epoch + 1) % 2 == 0:
        checkpoint_path = MODEL_DIR / f"glare_model_epoch_{epoch + 1}.pth"
        torch.save(model.state_dict(), checkpoint_path)
        print(f"📌 Checkpoint saved at: {checkpoint_path}")

# ✅ Save final model
final_model_path = MODEL_DIR / "final_glare_removal_model.pth"
torch.save(model.state_dict(), final_model_path)
print(f"🎉 Final Model Saved at: {final_model_path}")


Using device: cpu


Epoch 1/5: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 750/750 [1:37:59<00:00,  7.84s/it]


Epoch 1 - Training Loss: 0.2343
Epoch 1 - Validation Loss: 0.1971
✅ Best Model Saved at: ..\models\best_glare_removal_model.pth


Epoch 2/5: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 750/750 [1:32:40<00:00,  7.41s/it]


Epoch 2 - Training Loss: 0.1693
Epoch 2 - Validation Loss: 0.1751
✅ Best Model Saved at: ..\models\best_glare_removal_model.pth
📌 Checkpoint saved at: ..\models\glare_model_epoch_2.pth


Epoch 3/5: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 750/750 [1:31:52<00:00,  7.35s/it]


Epoch 3 - Training Loss: 0.1625
Epoch 3 - Validation Loss: 0.1691
✅ Best Model Saved at: ..\models\best_glare_removal_model.pth


Epoch 4/5: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 750/750 [1:40:20<00:00,  8.03s/it]


Epoch 4 - Training Loss: 0.1583
Epoch 4 - Validation Loss: 0.1740
📌 Checkpoint saved at: ..\models\glare_model_epoch_4.pth


Epoch 5/5: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 750/750 [2:00:28<00:00,  9.64s/it]


Epoch 5 - Training Loss: 0.1565
Epoch 5 - Validation Loss: 0.1660
✅ Best Model Saved at: ..\models\best_glare_removal_model.pth
🎉 Final Model Saved at: ..\models\final_glare_removal_model.pth
