In [3]:
# Install necessary packages
!pip install torch torchvision matplotlib tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.models.segmentation import deeplabv3_mobilenet_v3_large
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
from tqdm import tqdm

# Check for GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Define dataset paths
DATA_DIR = Path("../data/SD1")
PROCESSED_VAL_DIR = DATA_DIR / "processed_val"
MODEL_PATH = Path("../models/final_unet_mobilenet.pth")

# Define Dataset Class
class GlareRemovalDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.image_pairs = []

        # Collect glare/gt image pairs
        glare_images = sorted(self.root_dir.glob("glare_*.png"))
        gt_images = sorted(self.root_dir.glob("gt_*.png"))

        for glare_img in glare_images:
            gt_img = self.root_dir / glare_img.name.replace("glare_", "gt_")
            if gt_img.exists():
                self.image_pairs.append((glare_img, gt_img))

    def __len__(self):
        return len(self.image_pairs)

    def __getitem__(self, idx):
        glare_path, gt_path = self.image_pairs[idx]
        glare_img = Image.open(glare_path).convert("RGB")
        gt_img = Image.open(gt_path).convert("RGB")

        if self.transform:
            glare_img = self.transform(glare_img)
            gt_img = self.transform(gt_img)

        return glare_img, gt_img

# Define Transforms
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Load validation dataset
val_dataset = GlareRemovalDataset(PROCESSED_VAL_DIR, transform=transform)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=2)

# Define Model
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.model = deeplabv3_mobilenet_v3_large(weights=None)  # No pretrained weights
        self.model.classifier[4] = nn.Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1))

    def forward(self, x):
        return self.model(x)['out']

# Load trained model
model = UNet().to(device)
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.eval()  # Set to evaluation mode
print("Model loaded successfully.")

# Define loss function
criterion = nn.L1Loss()

# Evaluate model on validation dataset
total_loss = 0.0
with torch.no_grad():
    for images, targets in tqdm(val_loader, desc="Evaluating"):
        images, targets = images.to(device), targets.to(device)
        outputs = model(images)
        loss = criterion(outputs, targets)
        total_loss += loss.item()

# Print average L1 loss
avg_loss = total_loss / len(val_loader)
print(f"Average L1 Loss on Validation Set: {avg_loss:.4f}")

# Visualize some predictions
def show_images(original, predicted, ground_truth):
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    axes[0].imshow(original.permute(1, 2, 0))
    axes[0].set_title("Input (Glare)")
    axes[1].imshow(predicted.permute(1, 2, 0).clip(0, 1))  # Clamp values
    axes[1].set_title("Predicted (De-glared)")
    axes[2].imshow(ground_truth.permute(1, 2, 0))
    axes[2].set_title("Ground Truth")
    plt.show()

# Display first few images
for i, (images, targets) in enumerate(val_loader):
    if i == 3:  # Show 3 batches
        break
    images, targets = images.to(device), targets.to(device)
    with torch.no_grad():
        outputs = model(images)
    for j in range(len(images)):
        show_images(images[j].cpu(), outputs[j].cpu(), targets[j].cpu())


Using device: cpu


RuntimeError: Error(s) in loading state_dict for UNet:
	Unexpected key(s) in state_dict: "model.aux_classifier.0.weight", "model.aux_classifier.1.weight", "model.aux_classifier.1.bias", "model.aux_classifier.1.running_mean", "model.aux_classifier.1.running_var", "model.aux_classifier.1.num_batches_tracked", "model.aux_classifier.4.weight", "model.aux_classifier.4.bias". 