In [7]:
import numpy as np
import os
import matplotlib.pyplot as plt
import glob
import cv2
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import argparse
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from torchvision.utils import save_image
from sklearn.model_selection import train_test_split

In [8]:
# Training options declared argparse-style. They are parsed from an empty list
# because inside a notebook kernel sys.argv holds Jupyter's own arguments.
parser = argparse.ArgumentParser()
parser.add_argument(
    '-e', '--epochs',
    type=int,
    default=20,
    help='number of epochs to train the model for',
)
args = parser.parse_args([])
print(args.epochs)

In [9]:
# helper functions
image_dir = '../output/kaggle/working/output_images'
os.makedirs(image_dir, exist_ok=True)

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

print(device)

batch_size = 2

In [10]:
# List the blurred and sharp image files. Sorting both listings is what pairs
# them: entry i of x_blur is assumed to correspond to entry i of y_sharp.
# NOTE(review): DeblurDataset reads the images from '../input/cv-data/...resized_*_2'
# directories, not from the 'cv-images' directories listed here; confirm the
# file names agree between the two locations.
gauss_blur = sorted(os.listdir('../input/cv-images/blurred-20220427T190930Z-001/blurred'))
sharp = sorted(os.listdir('../input/cv-images/sharp-20220427T190932Z-001/sharp'))

# Fail early with a clear message instead of inside train_test_split.
assert len(gauss_blur) == len(sharp), (
    f"{len(gauss_blur)} blurred vs {len(sharp)} sharp images; cannot pair them"
)

x_blur = list(gauss_blur)
y_sharp = list(sharp)

In [11]:
# Build a synthetic blurred copy of every sharp image with a 31x31 Gaussian
# kernel and save it under ./gaussian_blur/ with the same file name.
# NOTE(review): nothing later in this notebook reads ./gaussian_blur/.
sharp_dir = '../input/cv-images/sharp-20220427T190932Z-001/sharp'
gaussian_blur_dir = './gaussian_blur'
# cv2.imwrite does not create missing directories; it only reports failure
# through its return value, so create the target folder up front.
os.makedirs(gaussian_blur_dir, exist_ok=True)

for name in tqdm(sharp, total=len(sharp)):
    # `sharp` lists the sharp directory, so read each file from there.
    img = cv2.imread(f"{sharp_dir}/{name}", cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imread returns None rather than raising on an unreadable file.
        raise FileNotFoundError(f"could not read image: {sharp_dir}/{name}")
    blur = cv2.GaussianBlur(img, (31, 31), 0)
    if not cv2.imwrite(f"{gaussian_blur_dir}/{name}", blur):
        raise OSError(f"could not write image: {gaussian_blur_dir}/{name}")

In [12]:
# Hold out 25% of the image pairs for validation. A fixed random_state makes
# the split, and therefore every reported loss, reproducible across restarts.
SPLIT_SEED = 42
(x_train, x_val, y_train, y_val) = train_test_split(
    x_blur, y_sharp, test_size=0.25, random_state=SPLIT_SEED
)
print(f"Train data instances: {len(x_train)}")
print(f"Validation data instances: {len(x_val)}")

In [13]:
# Per-image pipeline: numpy HWC array -> PIL image -> resized to 224x224 -> tensor.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


def change_dim(I):
    """Turn one channel-first image tensor into a channel-last numpy array for plt.imshow.

    Only the first three channels are kept, in their existing order.
    NOTE(review): the order is not swapped here, so displayed colours are only
    correct if the tensor already holds RGB (cv2.imread loads BGR) — confirm.
    """
    channels = I.cpu().data.numpy()
    return np.stack((channels[0], channels[1], channels[2]), axis=-1)


def change(batch):
    """Apply change_dim to each image of a batch and stack them into one array."""
    return np.array([change_dim(item) for item in batch])

In [14]:
class DeblurDataset(Dataset):
    """Paired blurred/sharp image dataset loaded from disk with OpenCV.

    Each item is ``(blur_image, sharp_image)`` when ``sharp_paths`` is given,
    otherwise just ``blur_image``. Images are converted from OpenCV's BGR
    channel order to RGB before ``transforms`` is applied.

    Args:
        blur_paths: file names of the blurred images, resolved against ``blur_dir``.
        sharp_paths: file names of the sharp targets, resolved against
            ``sharp_dir``; entry ``i`` is paired with ``blur_paths[i]``.
            ``None`` for inference-only use.
        transforms: optional callable applied to every loaded image
            (blurred and sharp alike).
        blur_dir: directory holding the blurred images.
        sharp_dir: directory holding the sharp images.

    Raises (from ``__getitem__``):
        FileNotFoundError: if OpenCV cannot read an image file.
    """

    # Defaults reproduce the locations this dataset has always read from.
    # NOTE(review): the file names are listed from '../input/cv-images/...'
    # directories while images are read from '../input/cv-data/...'; confirm
    # the two agree.
    DEFAULT_BLUR_DIR = '../input/cv-data/resized_blur_2-20220427T164258Z-001/resized_blur_2'
    DEFAULT_SHARP_DIR = '../input/cv-data/resized_sharp_2-20220427T164449Z-001/resized_sharp_2'

    def __init__(self, blur_paths, sharp_paths=None, transforms=None,
                 blur_dir=DEFAULT_BLUR_DIR, sharp_dir=DEFAULT_SHARP_DIR):
        self.X = blur_paths
        self.y = sharp_paths
        self.transforms = transforms
        self.blur_dir = blur_dir
        self.sharp_dir = sharp_dir

    def __len__(self):
        return len(self.X)

    def _read_image(self, path):
        """Read ``path`` with OpenCV and return it as an RGB array."""
        image = cv2.imread(path)
        if image is None:
            # cv2.imread returns None instead of raising on a missing/unreadable file.
            raise FileNotFoundError(f"could not read image: {path}")
        # OpenCV loads BGR; convert so tensors (and plt.imshow output) are RGB.
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def _load(self, directory, name):
        """Read ``directory/name`` and apply ``self.transforms`` when one is set."""
        image = self._read_image(f"{directory}/{name}")
        if self.transforms:
            image = self.transforms(image)
        return image

    def __getitem__(self, i):
        blur_image = self._load(self.blur_dir, self.X[i])
        if self.y is None:
            return blur_image
        # Same transform handling as the blurred image, so transforms=None works.
        sharp_image = self._load(self.sharp_dir, self.y[i])
        return (blur_image, sharp_image)

In [15]:
# Wrap the split file lists as datasets and batch them.
train_data = DeblurDataset(x_train, y_train, transform)
val_data = DeblurDataset(x_val, y_val, transform)

# Training batches are reshuffled each epoch (seeded so runs are reproducible);
# validation keeps a fixed order so the "last batch" shown by validate() is stable.
trainloader = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),
)
valloader = DataLoader(val_data, batch_size=batch_size, shuffle=False)

In [16]:
# Sanity check: load one training pair and show its blurred input.
# __getitem__ reads both images from disk, so fetch the sample only once.
sample = train_data[0]
print(type(sample))
blur_tensor = sample[0]
print(blur_tensor.shape)
plt.imshow(change_dim(blur_tensor))
plt.title("blur (train_data[0])")
plt.show()

In [17]:
class DeblurCNN(nn.Module):
    """Three-layer fully-convolutional deblurring network (9x9 -> 1x1 -> 5x5).

    Maps a 3-channel image batch to a 3-channel batch of the same height and
    width. Every convolution uses "same" padding (``kernel_size // 2``), so each
    layer preserves the spatial size and no output pixel is computed purely
    from zero padding.
    """

    def __init__(self):
        super(DeblurCNN, self).__init__()
        # Kernel shapes are unchanged, so saved state_dicts remain loadable.
        # Padding the 1x1 conv2 instead of conv1 would produce a border ring
        # of conv2 outputs that never sees image data (bias only).
        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=1, padding=0)
        self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return self.conv3(x)
# Build the network, move its parameters to the selected device, and show its layers.
model = DeblurCNN()
model = model.to(device)
print(model)

In [18]:
# Pixel-wise mean-squared error between network output and sharp target.
criterion = nn.MSELoss()
# Adam over every model parameter.
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Halve the learning rate once the monitored loss (stepped with the validation
# loss in the training loop) stops improving for more than 5 epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5,
    verbose=True,
)

In [19]:
def fit(model, dataloader, epoch):
    """Train ``model`` for one pass over ``dataloader`` and return the per-sample mean loss.

    Relies on the notebook-level ``optimizer``, ``criterion`` and ``device``.

    Args:
        model: network to train; switched to training mode.
        dataloader: yields ``(blur_batch, sharp_batch)`` pairs.
        epoch: current epoch index (unused; kept so the signature matches ``validate``).

    Returns:
        The batch-size-weighted average of the per-batch losses (0.0 for an
        empty loader). Assumes ``criterion`` averages over the batch, as
        ``nn.MSELoss`` does by default.
    """
    model.train()
    running_loss = 0.0
    num_samples = 0
    # len(dataloader) follows the loader actually passed in and counts the
    # final partial batch.
    for data in tqdm(dataloader, total=len(dataloader)):
        blur_image = data[0].to(device)
        sharp_image = data[1].to(device)
        optimizer.zero_grad()
        outputs = model(blur_image)
        loss = criterion(outputs, sharp_image)
        # backpropagation
        loss.backward()
        # update the parameters
        optimizer.step()
        # loss is a batch mean: weight it by the batch size so a short final
        # batch counts proportionally and the result is a per-sample mean.
        batch_len = blur_image.size(0)
        running_loss += loss.item() * batch_len
        num_samples += batch_len

    train_loss = running_loss / num_samples if num_samples else 0.0
    print(f"Train Loss: {train_loss:.5f}")

    return train_loss

In [20]:
def _show_image(image, title):
    """Display one channel-last image array in its own pyplot figure with ``title``."""
    plt.imshow(image)
    plt.title(title)
    plt.show()


def validate(model, dataloader, epoch):
    """Evaluate ``model`` on ``dataloader`` without gradients and return the per-sample mean loss.

    On the last batch the model outputs are displayed. On epoch 0 that batch's
    sharp targets and blurred inputs are displayed first (interleaved per image),
    since they do not change between epochs. Relies on the notebook-level
    ``criterion``, ``device`` and ``change``.

    Args:
        model: network to evaluate; switched to eval mode.
        dataloader: yields ``(blur_batch, sharp_batch)`` pairs.
        epoch: current epoch index; controls the one-time input/target display.

    Returns:
        The batch-size-weighted average of the per-batch losses (0.0 for an
        empty loader). Assumes ``criterion`` averages over the batch.
    """
    model.eval()
    running_loss = 0.0
    num_samples = 0
    # Index of the final (possibly partial) batch of *this* loader.
    last_batch = len(dataloader) - 1
    with torch.no_grad():
        for batch_idx, data in tqdm(enumerate(dataloader), total=len(dataloader)):
            blur_image = data[0].to(device)
            sharp_image = data[1].to(device)
            outputs = model(blur_image)
            loss = criterion(outputs, sharp_image)
            batch_len = blur_image.size(0)
            running_loss += loss.item() * batch_len
            num_samples += batch_len

            if batch_idx != last_batch:
                continue
            # Distinct loop names so the batch index is never clobbered.
            if epoch == 0:
                for sharp_np, blur_np in zip(change(sharp_image), change(blur_image)):
                    _show_image(sharp_np, "sharp")
                    _show_image(blur_np, "blur")
            for output_np in change(outputs):
                _show_image(output_np, "output")

    val_loss = running_loss / num_samples if num_samples else 0.0
    print(f"Val Loss: {val_loss:.5f}")

    return val_loss

In [21]:
# Train for args.epochs epochs, recording per-epoch losses for the plot below
# and stepping the LR scheduler on the validation loss.
train_loss = []
val_loss = []
# perf_counter is monotonic, so the duration is unaffected by system clock changes.
start = time.perf_counter()
for epoch in range(args.epochs):
    print(f"Epoch {epoch+1} of {args.epochs}")
    train_epoch_loss = fit(model, trainloader, epoch)
    val_epoch_loss = validate(model, valloader, epoch)
    train_loss.append(train_epoch_loss)
    val_loss.append(val_epoch_loss)
    scheduler.step(val_epoch_loss)
end = time.perf_counter()
print(f"Took {((end-start)/60):.3f} minutes to train")

In [22]:
# Per-epoch training vs. validation loss. Epochs are numbered from 1 to match
# the "Epoch N of M" progress messages printed during training.
epochs_axis = range(1, len(train_loss) + 1)
fig, ax = plt.subplots(figsize=(10, 7))
ax.plot(epochs_axis, train_loss, color='orange', label='train loss')
ax.plot(epochs_axis, val_loss, color='red', label='validation loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('Training and validation loss')
ax.legend()
plt.show()