# UNet using OOPS

In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

Defining the double convolution step

In [2]:
class DoubleConv(nn.Module):
  def __init__(self, in_channels, out_channels):
    super(DoubleConv, self).__init__()
    self.conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias = False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias = False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True)
    )

  def forward(self,x):
      return self.conv(x)

Defining the downsampling + double convolution step

In [3]:
class Down(nn.Module):
  def __init__(self, in_channels, out_channels):
    super(Down, self).__init__()
    self.avgpool_conv = nn.Sequential(
        nn.AvgPool2d(2),
        DoubleConv(in_channels, out_channels)
    )

  def forward(self, x):
      return self.avgpool_conv(x)

Defining the upsampling + double convolution step

In [4]:
class Up(nn.Module):
  def __init__(self, in_channels, out_channels):
    super(Up, self).__init__()
    self.up = nn.ConvTranspose2d(in_channels, in_channels//2, 2, 2)
    self.conv = DoubleConv(in_channels, out_channels)

  def forward(self, x1, skip):
    x1 = self.up(x1)
    diffY = skip.size()[2] - x1.size()[2]
    diffX = skip.size()[3] - x1.size()[3]

    x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
    x = torch.cat([skip, x1], dim=1)
    return self.conv(x)


Defining the out convolution

In [5]:
class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        return self.sigmoid(self.conv(x))

Defining the model

In [6]:
class UNet(nn.Module):
  def __init__(self, in_channels = 3, out_channels = 3, features = [64,128,256,512]):
    super(UNet, self).__init__()
    self.inc = DoubleConv(in_channels, features[0])
    self.down1 = Down(features[0], features[1])
    self.down2 = Down(features[1], features[2])
    self.down3 = Down(features[2], features[3])
    self.down4 = Down(features[3], features[3]*2) #bridging step
    self.up1 = Up(features[3]*2, features[3])
    self.up2 = Up(features[3], features[2])
    self.up3 = Up(features[2], features[1])
    self.up4 = Up(features[1], features[0])
    self.outc = OutConv(features[0], out_channels)

  def forward(self, x):
    x1 = self.inc(x)
    x2 = self.down1(x1)
    x3 = self.down2(x2)
    x4 = self.down3(x3)
    x5 = self.down4(x4)
    x = self.up1(x5, x4)
    x = self.up2(x, x3)
    x = self.up3(x, x2)
    x = self.up4(x, x1)
    logits = self.outc(x)
    return logits

# Loading Data

In [7]:
!pip install -q kaggle

In [8]:
from google.colab import files
files.upload()

Saving kaggle.json to kaggle.json


{'kaggle.json': b'{"username":"idhantkadela","key":"d60aac77e655a944cb9bdb3890e40621"}'}

In [9]:
!mkdir ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

In [10]:
!kaggle datasets download -d kwentar/blur-dataset

Dataset URL: https://www.kaggle.com/datasets/kwentar/blur-dataset
License(s): CC0-1.0
Downloading blur-dataset.zip to /content
100% 1.49G/1.49G [00:14<00:00, 136MB/s]
100% 1.49G/1.49G [00:14<00:00, 108MB/s]


In [11]:
!unzip blur-dataset.zip

Archive:  blur-dataset.zip
  inflating: blur_dataset_scaled/defocused_blurred/0_IPHONE-SE_F.JPG  
  inflating: blur_dataset_scaled/defocused_blurred/100_NIKON-D3400-35MM_F.JPG  
  inflating: blur_dataset_scaled/defocused_blurred/101_NIKON-D3400-35MM_F.JPG  
  inflating: blur_dataset_scaled/defocused_blurred/102_NIKON-D3400-35MM_F.JPG  
  inflating: blur_dataset_scaled/defocused_blurred/103_HUAWEI-P20_F.jpg  
  inflating: blur_dataset_scaled/defocused_blurred/104_IPHONE-SE_F.jpg  
  inflating: blur_dataset_scaled/defocused_blurred/105_IPHONE-SE_F.jpg  
  inflating: blur_dataset_scaled/defocused_blurred/106_NIKON-D3400-35MM_F.JPG  
  inflating: blur_dataset_scaled/defocused_blurred/107_XIAOMI-MI8-SE_F.jpg  
  inflating: blur_dataset_scaled/defocused_blurred/108_XIAOMI-MI8-SE_F.jpg  
  inflating: blur_dataset_scaled/defocused_blurred/109_HONOR-7X_F.jpg  
  inflating: blur_dataset_scaled/defocused_blurred/10_ASUS-ZENFONE-LIVE-ZB501KL_F.jpg  
  inflating: blur_dataset_scaled/defocused_blurr

In [12]:
sorted(os.listdir('/content/motion_blurred'))
sorted(os.listdir('/content/sharp'))

['0_IPHONE-SE_S.JPG',
 '100_NIKON-D3400-35MM_S.JPG',
 '101_NIKON-D3400-35MM_S.JPG',
 '102_NIKON-D3400-35MM_S.JPG',
 '103_HUAWEI-P20_S.jpg',
 '104_IPHONE-SE_S.jpg',
 '105_IPHONE-SE_S.jpg',
 '106_NIKON-D3400-35MM_S.JPG',
 '107_XIAOMI-MI8-SE_S.jpg',
 '108_XIAOMI-MI8-SE_S.jpg',
 '109_HONOR-7X_S.jpg',
 '10_ASUS-ZENFONE-LIVE-ZB501KL_S.jpg',
 '110_IPHONE-7_S.jpeg',
 '111_IPHONE-7_S.jpeg',
 '112_NIKON-D3400-35MM_S.JPG',
 '113_SAMSUNG-GALAXY-A5_S.jpg',
 '114_ASUS-ZE500KL_S.jpg',
 '115_NIKON-D3400-35MM_S.JPG',
 '116_BQ-5512L_S.jpg',
 '117_HONOR-7X_S.jpg',
 '118_HONOR-7X_S.jpg',
 '119_HONOR-7X_S.jpg',
 '11_XIAOMI-MI8-SE_S.jpg',
 '120_HONOR-7X_S.jpg',
 '121_HONOR-7X_S.jpg',
 '122_HONOR-7X_S.jpg',
 '123_NIKON-D3400-35MM_S.JPG',
 '124_HONOR-7X_S.jpg',
 '125_NIKON-D3400-35MM_S.JPG',
 '126_NIKON-D3400-18-55MM_S.JPG',
 '127_IPHONE-8_S.jpeg',
 '128_XIAOMI-MI8-SE_S.jpg',
 '129_NIKON-D3400-18-55MM_S.JPG',
 '12_SAMSUNG-GALAXY-J5_S.jpg',
 '130_NIKON-D3400-18-55MM_S.JPG',
 '131_NIKON-D3400-18-55MM_S.JPG',
 '

In [13]:
class UNetDataset(Dataset):
    def __init__(self, dir):
        self.dir = dir
        self.images = sorted(list(self.dir + '/' + file_name for file_name in sorted(os.listdir(self.dir))))
        self.transform = transforms.Compose([
            transforms.Resize((286, 286)),
            transforms.ToTensor()])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img = Image.open(self.images[index]).convert("RGB")
        img = self.transform(img)
        return img

In [14]:
def show_batch(image_batch, title):
    grid = torchvision.utils.make_grid(image_batch)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))
    plt.title(title)
    plt.show()

In [15]:
blurred_dataset = UNetDataset(dir='/content/motion_blurred')
sharp_dataset = UNetDataset(dir='/content/sharp')

# Training

In [16]:
LEARNING_RATE = 0.001
BATCH_SIZE = 4
EPOCHS = 20

In [18]:
indices = np.arange(0, blurred_dataset.__len__())
np.random.shuffle(indices)
indices

array([125, 136, 129,  11, 196,  18, 268,  65, 112,  90,  25, 330, 289,
        57,  85, 164,   1, 228, 173, 134,  76, 321, 317,  16, 160,  55,
       302, 250, 314, 165, 285, 188, 141,  71, 341, 207, 271, 304,  89,
       240, 292, 237, 167, 186,  49, 178, 104, 222, 224, 156,  83, 175,
       100, 328,  82, 276, 171,   3, 251, 266, 194, 309,   6,  60, 193,
        72, 225,  37, 190, 310, 270, 144,  13,   0, 247, 298, 110, 174,
       259,  48,  87, 177,  32, 301, 263, 345, 329, 267,  77, 342, 308,
       307, 332,  46, 305, 120, 283, 262,  59, 294, 287, 152, 154,  44,
       206, 234, 258, 246, 335, 346, 149,  50,  96,  62, 148, 299, 135,
       212, 284, 123, 105, 180,  20, 150,  23, 277, 288, 214, 189,  74,
       296, 102, 179, 320, 151, 311, 208, 227, 319, 101, 109,  88, 145,
        41, 336, 338, 138, 264, 333, 159, 114,  61, 163, 147, 220, 326,
        51,  30, 219, 306,  39, 273, 218, 126, 323,   8, 257,  67,  69,
        98,  92, 139, 142, 157, 322, 241, 168,  19, 127,  22, 20