In [None]:
TASK 2: Image Deblurring Using the U-Net Architecture

Objective: The objective of this task is to develop a deep learning model to perform image deblurring and denoising using the UNet architecture. The provided dataset consists of triplets of images for each scene: sharp, defocused-blurred, and motion-blurred images. The model should be capable of taking a blurred image as input and generating a high-quality, sharp, and noise-free image as output.

PS: Model performance will be evaluated using the L2 distance between the predicted sharp image and the ground-truth sharp image, on both the training and test sets. Build and train a UNet-based model that deblurs and denoises the defocused-blurred and motion-blurred images, restoring each one to its corresponding sharp version. Submissions must use the PyTorch framework only.

Deadline: Since the secy tasks have been released, the deadline for this task is 25 May, end of day (EOD).

Link To the Dataset: https://www.kaggle.com/datasets/kwentar/blur-dataset

Understanding UNet: https://paperswithcode.com/method/u-net

In [1]:
import pickle
import cv2
import os
import torch
import torch.nn as nn
from torchvision import models
from torch.nn.functional import relu
import numpy as np
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

In [2]:
# Accumulators for the decoded images of each category; filled by the loading cells below.
sharp, defocused_blurred, motion_blurred = [], [], []

In [3]:
def _list_image_files(folder):
    """Return the sorted file names in `folder`, skipping hidden entries such as .DS_Store.

    Sorting is what pairs the three folders: the i-th sharp, defocused and motion file
    end up at the same index and are later split/trained on as one scene.
    """
    return sorted(name for name in os.listdir(folder) if not name.startswith("."))


sharp_images = _list_image_files("archive/sharp")
defocused_blurred_images = _list_image_files("archive/defocused_blurred")
motion_blurred_images = _list_image_files("archive/motion_blurred")

# Scenes are paired by position after sorting, so the three folders must be the same size.
assert len(sharp_images) == len(defocused_blurred_images) == len(motion_blurred_images), (
    "sharp / defocused_blurred / motion_blurred folders contain different numbers of files"
)

In [4]:
# Decode every sharp image as RGB, make it portrait, and resize to a fixed 528x352.
for fname in sharp_images:
    path = "archive/sharp/" + fname
    img = cv2.imread(path)
    if img is None:
        # cv2.imread signals unreadable / non-image files by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR
    # Transpose (swap rows and columns) landscape images so every image is portrait.
    if img.shape[1] > img.shape[0]:
        img = cv2.transpose(img)
    # dsize is (width, height): the result is 528 rows x 352 cols x 3. Both sides are
    # multiples of 16, so they survive UNet's four 2x downsamplings without rounding.
    sharp.append(cv2.resize(img, (352, 528)))

In [5]:
# Decode every defocus-blurred image as RGB, make it portrait, and resize to a fixed 528x352.
for fname in defocused_blurred_images:
    path = "archive/defocused_blurred/" + fname
    img = cv2.imread(path)
    if img is None:
        # cv2.imread signals unreadable / non-image files by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR
    # Transpose (swap rows and columns) landscape images so every image is portrait.
    if img.shape[1] > img.shape[0]:
        img = cv2.transpose(img)
    # dsize is (width, height): the result is 528 rows x 352 cols x 3.
    defocused_blurred.append(cv2.resize(img, (352, 528)))

In [6]:
# Decode every motion-blurred image as RGB, make it portrait, and resize to a fixed 528x352.
for fname in motion_blurred_images:
    path = "archive/motion_blurred/" + fname
    img = cv2.imread(path)
    if img is None:
        # cv2.imread signals unreadable / non-image files by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR
    # Transpose (swap rows and columns) landscape images so every image is portrait.
    if img.shape[1] > img.shape[0]:
        img = cv2.transpose(img)
    # dsize is (width, height): the result is 528 rows x 352 cols x 3.
    motion_blurred.append(cv2.resize(img, (352, 528)))

In [7]:
# Sanity check: one image per scene in each category (sharp, motion, defocused).
print(*(len(images) for images in (sharp, motion_blurred, defocused_blurred)))

350 350 350


In [8]:
# Stack each list of equally sized HxWx3 images into a single (N, H, W, 3) array.
sharp, defocused_blurred, motion_blurred = (
    np.array(sharp),
    np.array(defocused_blurred),
    np.array(motion_blurred),
)

In [9]:
def _to_unit_float(arr):
    """Scale integer pixel values to [0, 1] as float32; return float arrays unchanged.

    float32 halves memory compared with the float64 that `uint8 / 255.0` yields, and the
    dataset converts to float32 anyway. Skipping arrays that are already floating point
    makes this cell safe to re-run: the original form divided by 255 again on every run.
    """
    if np.issubdtype(arr.dtype, np.floating):
        return arr
    return arr.astype(np.float32) / np.float32(255.0)


sharp = _to_unit_float(sharp)
defocused_blurred = _to_unit_float(defocused_blurred)
motion_blurred = _to_unit_float(motion_blurred)

In [10]:
class UNet(nn.Module):
    """U-Net encoder/decoder that maps an image to an image of the same spatial size.

    Architecture (c = base_channels):
      * Encoder: four stages of two 3x3 convs + ReLU followed by a 2x2 average pool
        (widths c, 2c, 4c, 8c), then a bottleneck of two 3x3 convs + ReLU with 16c
        channels and no pool.
      * Decoder: four stages of a 2x2 stride-2 transposed conv (doubles H and W, halves
        channels), concatenation with the same-resolution encoder output (skip
        connection), then two 3x3 convs + ReLU.
      * Output: a 1x1 conv to ``out_channels`` with no activation, so predictions are
        not clamped to any range.

    All 3x3 convs use padding=1 and therefore preserve H and W; only the pools and the
    transposed convs change resolution.

    Args:
        in_channels: channels of the input image (default 3, RGB).
        out_channels: channels of the output image (default 3, RGB).
        base_channels: width of the first encoder stage (default 64, giving the
            64-128-256-512-1024 widths of the original U-Net).

    Input shape is (N, in_channels, H, W). The four 2x pools require H and W to be
    multiples of 16 for the skip connections to line up. Other sizes are edge-padded
    (replicate) up to the next multiple of 16, and the output is cropped back to (H, W).
    """

    # 2 ** (number of pooling stages): internal H and W must be divisible by this.
    _SIZE_MULTIPLE = 16

    def __init__(self, in_channels=3, out_channels=3, base_channels=64):
        super().__init__()
        c1, c2, c3, c4, c5 = (base_channels * 2 ** k for k in range(5))

        # Encoder. Comments give each layer's output for an H x W input.
        self.e11 = nn.Conv2d(in_channels, c1, kernel_size=3, padding=1)  # H x W x c1
        self.e12 = nn.Conv2d(c1, c1, kernel_size=3, padding=1)  # H x W x c1
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)  # H/2 x W/2 x c1

        self.e21 = nn.Conv2d(c1, c2, kernel_size=3, padding=1)  # H/2 x W/2 x c2
        self.e22 = nn.Conv2d(c2, c2, kernel_size=3, padding=1)  # H/2 x W/2 x c2
        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)  # H/4 x W/4 x c2

        self.e31 = nn.Conv2d(c2, c3, kernel_size=3, padding=1)  # H/4 x W/4 x c3
        self.e32 = nn.Conv2d(c3, c3, kernel_size=3, padding=1)  # H/4 x W/4 x c3
        self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2)  # H/8 x W/8 x c3

        self.e41 = nn.Conv2d(c3, c4, kernel_size=3, padding=1)  # H/8 x W/8 x c4
        self.e42 = nn.Conv2d(c4, c4, kernel_size=3, padding=1)  # H/8 x W/8 x c4
        self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2)  # H/16 x W/16 x c4

        # Bottleneck (no pooling).
        self.e51 = nn.Conv2d(c4, c5, kernel_size=3, padding=1)  # H/16 x W/16 x c5
        self.e52 = nn.Conv2d(c5, c5, kernel_size=3, padding=1)  # H/16 x W/16 x c5

        # Decoder. Each upconv doubles H and W and halves channels. Concatenating the
        # matching encoder output doubles channels again, so each d*1 conv takes twice
        # its output width.
        self.upconv1 = nn.ConvTranspose2d(c5, c4, kernel_size=2, stride=2)  # H/8 x W/8 x c4
        self.d11 = nn.Conv2d(c4 * 2, c4, kernel_size=3, padding=1)
        self.d12 = nn.Conv2d(c4, c4, kernel_size=3, padding=1)

        self.upconv2 = nn.ConvTranspose2d(c4, c3, kernel_size=2, stride=2)  # H/4 x W/4 x c3
        self.d21 = nn.Conv2d(c3 * 2, c3, kernel_size=3, padding=1)
        self.d22 = nn.Conv2d(c3, c3, kernel_size=3, padding=1)

        self.upconv3 = nn.ConvTranspose2d(c3, c2, kernel_size=2, stride=2)  # H/2 x W/2 x c2
        self.d31 = nn.Conv2d(c2 * 2, c2, kernel_size=3, padding=1)
        self.d32 = nn.Conv2d(c2, c2, kernel_size=3, padding=1)

        self.upconv4 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2)  # H x W x c1
        self.d41 = nn.Conv2d(c1 * 2, c1, kernel_size=3, padding=1)
        self.d42 = nn.Conv2d(c1, c1, kernel_size=3, padding=1)

        # Output layer: 1x1 projection to the output channels, no activation.
        self.outconv = nn.Conv2d(c1, out_channels, kernel_size=1)

    def forward(self, x):
        """Run the network on an (N, C, H, W) batch and return an (N, out_channels, H, W) batch."""
        h, w = x.shape[-2:]
        pad_h = (-h) % self._SIZE_MULTIPLE
        pad_w = (-w) % self._SIZE_MULTIPLE
        if pad_h or pad_w:
            # The pad tuple is (left, right, top, bottom) for the last two dims. Replicate
            # padding has no size limit, unlike reflect padding.
            x = nn.functional.pad(x, (0, pad_w, 0, pad_h), mode="replicate")

        # Encoder
        xe11 = relu(self.e11(x))
        xe12 = relu(self.e12(xe11))
        xp1 = self.pool1(xe12)

        xe21 = relu(self.e21(xp1))
        xe22 = relu(self.e22(xe21))
        xp2 = self.pool2(xe22)

        xe31 = relu(self.e31(xp2))
        xe32 = relu(self.e32(xe31))
        xp3 = self.pool3(xe32)

        xe41 = relu(self.e41(xp3))
        xe42 = relu(self.e42(xe41))
        xp4 = self.pool4(xe42)

        xe51 = relu(self.e51(xp4))
        xe52 = relu(self.e52(xe51))

        # Decoder: upsample, concatenate the skip connection on the channel dim, refine.
        xu1 = self.upconv1(xe52)
        xu11 = torch.cat([xu1, xe42], dim=1)
        xd11 = relu(self.d11(xu11))
        xd12 = relu(self.d12(xd11))

        xu2 = self.upconv2(xd12)
        xu22 = torch.cat([xu2, xe32], dim=1)
        xd21 = relu(self.d21(xu22))
        xd22 = relu(self.d22(xd21))

        xu3 = self.upconv3(xd22)
        xu33 = torch.cat([xu3, xe22], dim=1)
        xd31 = relu(self.d31(xu33))
        xd32 = relu(self.d32(xd31))

        xu4 = self.upconv4(xd32)
        xu44 = torch.cat([xu4, xe12], dim=1)
        xd41 = relu(self.d41(xu44))
        xd42 = relu(self.d42(xd41))

        # Output layer
        out = self.outconv(xd42)

        if pad_h or pad_w:
            # Drop the padded border so the output matches the caller's input size.
            out = out[..., :h, :w]
        return out

In [11]:
class ImageMaskDataset(Dataset):
    """Index-aligned pairs of arrays exposed as a torch Dataset.

    Item ``i`` is ``(images[i], masks[i])``, each converted to a float32 tensor with its
    layout unchanged (no channel permutation). In this notebook ``images`` holds the
    blurred inputs and ``masks`` the matching sharp targets.
    """

    def __init__(self, images, masks):
        self.images, self.masks = images, masks

    def __len__(self):
        # The length comes from `images`; `masks` is expected to have the same length.
        return len(self.images)

    def __getitem__(self, idx):
        pair = (self.images[idx], self.masks[idx])
        return tuple(torch.tensor(item, dtype=torch.float32) for item in pair)

In [12]:
# Hold out 20% of the scenes for validation; the fixed random_state makes the split reproducible.
(
    train_images1,
    val_images,
    train_masks1,
    val_masks,
) = train_test_split(defocused_blurred, sharp, test_size=0.2, random_state=42)

In [13]:
train_dataset = ImageMaskDataset(train_images1, train_masks1)
val_dataset = ImageMaskDataset(val_images, val_masks)

# Reshuffle the training set every epoch so batches are not identical each pass. The
# seeded generator keeps the shuffle reproducible. Validation order does not need to change.
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),
)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

In [14]:
# Seed torch before building the model so weight initialisation is reproducible.
torch.manual_seed(42)
model = UNet()
# MSELoss with default reduction: squared L2 distance averaged over every pixel and channel.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
epochs = 1
# Put batches on whichever device the model's parameters live on (CPU unless moved).
device = next(model.parameters()).device

for epoch in range(epochs):
    # ---- train ----
    model.train()
    train_loss = 0.0
    for images, masks in train_loader:
        # Dataset yields (N, H, W, C); Conv2d expects (N, C, H, W).
        inputs = images.permute(0, 3, 1, 2).to(device)
        targets = masks.permute(0, 3, 1, 2).to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        # criterion returns a per-batch mean. Weighting it by batch size keeps a short
        # final batch from skewing the epoch average.
        train_loss += loss.item() * inputs.size(0)
    train_loss /= len(train_loader.dataset)

    # ---- validate ----
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, masks in val_loader:
            inputs = images.permute(0, 3, 1, 2).to(device)
            targets = masks.permute(0, 3, 1, 2).to(device)
            outputs = model(inputs)
            val_loss += criterion(outputs, targets).item() * inputs.size(0)
    val_loss /= len(val_loader.dataset)

    print(f"Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

In [12]:
# 80/20 split of the motion-blurred inputs. It uses the same random_state and the same number
# of samples as the defocused split, so the held-out scenes should match. Note that this
# overwrites val_images / val_masks from the defocused split.
(
    train_images2,
    val_images,
    train_masks2,
    val_masks,
) = train_test_split(motion_blurred, sharp, test_size=0.2, random_state=42)

In [13]:
train_dataset = ImageMaskDataset(train_images2, train_masks2)
val_dataset = ImageMaskDataset(val_images, val_masks)

# Reshuffle the training set every epoch so batches are not identical each pass. The
# seeded generator keeps the shuffle reproducible. Validation order does not need to change.
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),
)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)